Browse Source

fix backward

Denis Mazur 4 years ago
parent
commit
97654aceb5
1 changed files with 1 additions and 1 deletions
  1. 1 1
      hivemind/moe/client/expert.py

+ 1 - 1
hivemind/moe/client/expert.py

@@ -157,7 +157,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         grad_inputs = cls.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+            ctx.stub.rpc_backward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
         )
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]