|
@@ -173,7 +173,7 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
|
|
stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- return [deserialize_torch_tensor(t) for t in outputs]
|
|
|
|
|
|
|
+ return [deserialize_torch_tensor(t) for t in outputs.tensors]
|
|
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
|
@once_differentiable
|
|
@once_differentiable
|
|
@@ -226,4 +226,4 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
|
|
ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
|
|
|
|
|
|
|
+ return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
|