瀏覽代碼

fix missprint bug

Pavel Samygin 3 年之前
父節點
當前提交
fd967093f7
共有 1 個文件被更改,包括 2 次插入2 次删除
  1. 2 2
      hivemind/moe/client/expert.py

+ 2 - 2
hivemind/moe/client/expert.py

@@ -173,7 +173,7 @@ class _RemoteModuleCall(torch.autograd.Function):
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 
-        return [deserialize_torch_tensor(t) for t in outputs]
+        return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
     @classmethod
     @once_differentiable
@@ -226,4 +226,4 @@ class _RemoteModuleCall(torch.autograd.Function):
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 
-        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]