@@ -118,15 +118,15 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
         while True:
-            # try:
-            with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
-                forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
-                outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
-            break
-            # except KeyboardInterrupt:
-            #     break
-            # except BaseException as e:
-            #     logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except KeyboardInterrupt:
+                raise
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")

         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
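Note: the hunk above re-enables the retry loop's exception handling that had been commented out: a failed call is logged and the balancer hands out another expert, while Ctrl-C now propagates instead of being swallowed by the blanket `except BaseException`. Below is a minimal standalone sketch of the same retry-with-rebalancing pattern; `pick_expert`, `call_forward`, and `max_attempts` are illustrative stand-ins, not hivemind API:

    import logging

    logger = logging.getLogger(__name__)

    def robust_forward(pick_expert, call_forward, max_attempts=10):
        # Keep trying experts until one succeeds; never swallow Ctrl-C.
        for attempt in range(max_attempts):
            expert = pick_expert()
            try:
                return call_forward(expert)
            except KeyboardInterrupt:
                raise  # a user interrupt must propagate, not trigger another retry
            except BaseException as e:
                logger.error(f"Tried to call forward for expert {expert} but caught {repr(e)}")
        raise RuntimeError(f"forward failed after {max_attempts} attempts")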
@@ -147,6 +147,8 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
                     backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
                     grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
                 break
+            except KeyboardInterrupt:
+                raise
             except BaseException as e:
                 logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
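Note: in both hunks the order of the `except` clauses matters. Python matches them top to bottom, and `KeyboardInterrupt` subclasses `BaseException` (not `Exception`), so the specific handler must be listed first or it would never fire. A standalone illustration, not part of the patch:

    def run_step(step):
        try:
            step()
        except KeyboardInterrupt:
            raise  # reached only because this clause is listed first
        except BaseException as e:
            # would also match KeyboardInterrupt if the clause above were absent
            print(f"recovering from {e!r}")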