Browse Source

fix missprint bug

Pavel Samygin 3 years ago
parent
commit
fd967093f7
1 changed files with 2 additions and 2 deletions
  1. 2 2
      hivemind/moe/client/expert.py

+ 2 - 2
hivemind/moe/client/expert.py

@@ -173,7 +173,7 @@ class _RemoteModuleCall(torch.autograd.Function):
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 
-        return [deserialize_torch_tensor(t) for t in outputs]
+        return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
     @classmethod
     @once_differentiable
@@ -226,4 +226,4 @@ class _RemoteModuleCall(torch.autograd.Function):
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 
-        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]