justheuristic 4 жил өмнө
parent
commit
a51b0db279

+ 3 - 3
hivemind/moe/client/balanced_expert.py

@@ -84,7 +84,7 @@ class BalancedRemoteExpert(nn.Module):
                 with self.expert_balancer.use_another_expert(1) as chosen_expert:
                     self._expert_info = chosen_expert.info
             except BaseException as e:
-                logger.error(f"Tried to get expert info from {chosen_expert} but caught {e}")
+                logger.error(f"Tried to get expert info from {chosen_expert} but caught {repr(e)}")
         return self._expert_info
 
 
@@ -122,7 +122,7 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
                     outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
                 break
             except BaseException as e:
-                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {e}")
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
@@ -144,6 +144,6 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
                     grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
                 break
             except BaseException as e:
-                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {e}")
+                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)