@@ -84,7 +84,7 @@ class BalancedRemoteExpert(nn.Module):
                 with self.expert_balancer.use_another_expert(1) as chosen_expert:
                     self._expert_info = chosen_expert.info
             except BaseException as e:
-                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
         return self._expert_info
@@ -122,7 +122,7 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
                     outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
                 break
             except BaseException as e:
-                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
@@ -144,6 +144,6 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
                 grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
                 break
             except BaseException as e:
-                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
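
The diff swaps {e} for {repr(e)} in each error message: str(e) only yields the exception's message, which is empty for exceptions raised without arguments (e.g. a bare TimeoutError()), so the failure type was dropped from the log line, while repr(e) keeps both the exception type and its arguments. A minimal standalone sketch of the difference (plain Python, not hivemind code; the logger name is made up for illustration):

    import logging

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger("repr_demo")  # hypothetical logger name for this sketch

    try:
        raise TimeoutError()  # an exception raised with no message
    except BaseException as e:
        logger.error(f"caught {e}")        # logs just "caught " -- the type is lost
        logger.error(f"caught {repr(e)}")  # logs "caught TimeoutError()"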