|
@@ -133,11 +133,11 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
|
|
|
]
|
|
|
|
|
|
- async def func():
|
|
|
- stream = stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
|
|
|
- return await stream.__anext__()
|
|
|
-
|
|
|
- outputs = cls.run_coroutine(func())
|
|
|
+ outputs = cls.run_coroutine(
|
|
|
+ asingle(
|
|
|
+ stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
|
|
|
+ ),
|
|
|
+ )
|
|
|
|
|
|
deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
|
|
|
@@ -154,11 +154,11 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
|
|
|
]
|
|
|
|
|
|
- async def func():
|
|
|
- stream = ctx.stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
|
|
|
- return await asingle(stream)
|
|
|
-
|
|
|
- grad_inputs = cls.run_coroutine(func())
|
|
|
+ grad_inputs = cls.run_coroutine(
|
|
|
+ asingle(
|
|
|
+ ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
|
|
|
+ ),
|
|
|
+ )
|
|
|
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|
|
|
return (DUMMY, None, None, None, *deserialized_grad_inputs)
|