|
@@ -126,7 +126,7 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
except KeyboardInterrupt:
|
|
|
raise
|
|
|
except BaseException as e:
|
|
|
- logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
|
|
|
+ logger.exception(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
|
|
|
|
|
|
deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
|
return tuple(deserialized_outputs)
|
|
@@ -150,6 +150,6 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
except KeyboardInterrupt:
|
|
|
raise
|
|
|
except BaseException as e:
|
|
|
- logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
|
|
|
+ logger.exception(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|
|
|
return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
|