|
@@ -157,7 +157,7 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
]
|
|
|
|
|
|
grad_inputs = cls.run_coroutine(
|
|
|
- ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
|
|
|
+ ctx.stub.rpc_backward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
|
|
|
)
|
|
|
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|