Denis Mazur 4 rokov pred
rodič
commit
05ff01f12c
1 zmenil súbory, kde vykonal 2 pridanie a 2 odobranie
  1. 2 2
      hivemind/moe/client/expert.py

+ 2 - 2
hivemind/moe/client/expert.py

@@ -134,7 +134,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         outputs = cls.run_coroutine(
-            stub.rpc_forward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
+            stub.rpc_forward(as_aiter(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
         )
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
@@ -153,7 +153,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         grad_inputs = cls.run_coroutine(
-            ctx.stub.rpc_backward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
+            ctx.stub.rpc_backward(as_aiter(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
         )
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]