|
@@ -115,10 +115,10 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
serialize_torch_tensor(inp, proto.compression)
|
|
|
for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
|
|
|
]
|
|
|
- forward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
|
|
|
while True:
|
|
|
try:
|
|
|
with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
|
|
|
+ forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
|
|
|
outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
|
|
|
break
|
|
|
except BaseException as e:
|
|
@@ -138,10 +138,10 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
|
|
|
serialize_torch_tensor(tensor, proto.compression)
|
|
|
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
|
|
|
]
|
|
|
- backward_request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
|
|
|
while True:
|
|
|
try:
|
|
|
with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
|
|
|
+ backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
|
|
|
grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
|
|
|
break
|
|
|
except BaseException as e:
|