|
@@ -134,7 +134,7 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
]
|
|
|
|
|
|
outputs = cls.run_coroutine(
|
|
|
- stub.rpc_forward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
|
|
|
+ stub.rpc_forward(as_aiter(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
|
|
|
)
|
|
|
|
|
|
deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
@@ -153,7 +153,7 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
]
|
|
|
|
|
|
grad_inputs = cls.run_coroutine(
|
|
|
- ctx.stub.rpc_backward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
|
|
|
+ ctx.stub.rpc_backward(as_aiter(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
|
|
|
)
|
|
|
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|