|
@@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Tuple
|
|
|
|
|
|
import torch
|
|
|
from hivemind import serialize_torch_tensor
|
|
|
-from hivemind.moe.client.expert import expert_backward, expert_forward
|
|
|
+from hivemind.moe.client.expert import expert_backward, expert_forward, _forward_stream
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.p2p import StubBase
|
|
|
from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
|
|
@@ -39,14 +39,14 @@ async def run_expert_forward(
|
|
|
forward_inputs = nested_flatten(forward_inputs)
|
|
|
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
|
|
|
|
|
|
- # TODO: figure out whether we should use run_in_executor here
|
|
|
- serialized_tensors = (
|
|
|
- serialize_torch_tensor(tensor, proto.compression)
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ serialized_tensors = await asyncio.gather(*(
|
|
|
+ loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
|
|
|
for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
|
|
|
- )
|
|
|
+ ))
|
|
|
+
|
|
|
deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
|
|
|
flat_outputs = tuple(deserialized_outputs)
|
|
|
-
|
|
|
return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
|
|
|
|
|
|
|
|
@@ -67,10 +67,12 @@ async def run_expert_backward(
|
|
|
inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
|
|
|
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
|
|
|
|
|
|
- serialized_tensors = (
|
|
|
- serialize_torch_tensor(tensor, proto.compression)
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ serialized_tensors = await asyncio.gather(*(
|
|
|
+ loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
|
|
|
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
|
|
|
- )
|
|
|
+ ))
|
|
|
+
|
|
|
deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
|
|
|
return deserialized_grad_inputs
|
|
|
|