|
@@ -125,8 +125,8 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
break
|
|
|
except KeyboardInterrupt:
|
|
|
raise
|
|
|
- except BaseException as e:
|
|
|
- logger.exception(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
|
|
|
+ except BaseException:
|
|
|
+ logger.exception(f"Tried to call forward for expert {chosen_expert}:")
|
|
|
|
|
|
deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
|
return tuple(deserialized_outputs)
|
|
@@ -149,7 +149,7 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
break
|
|
|
except KeyboardInterrupt:
|
|
|
raise
|
|
|
- except BaseException as e:
|
|
|
- logger.exception(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
|
|
|
+ except BaseException:
|
|
|
+ logger.exception(f"Tried to call backward for expert {chosen_expert}:")
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|
|
|
return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
|