|
@@ -123,7 +123,6 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
break
|
|
|
except BaseException as e:
|
|
|
logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
|
|
|
- raise
|
|
|
|
|
|
deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
|
return tuple(deserialized_outputs)
|
|
@@ -146,6 +145,5 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
break
|
|
|
except BaseException as e:
|
|
|
logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
|
|
|
- raise
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|
|
|
return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
|