justheuristic vor 4 Jahren
Ursprung
Commit
6b9de865bc
1 geänderte Dateien mit 2 neuen und 2 gelöschten Zeilen
  1. 2 2
      hivemind/moe/client/balanced_expert.py

+ 2 - 2
hivemind/moe/client/balanced_expert.py

@@ -115,10 +115,10 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             serialize_torch_tensor(inp, proto.compression)
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
-        forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
         while True:
             try:
                 with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
                     outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
                 break
             except BaseException as e:
@@ -138,10 +138,10 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]
-        backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
         while True:
             try:
                 with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
+                    backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
                     grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
                 break
             except BaseException as e: