Procházet zdrojové kódy

async serialization & cast to the schema dtype

dbaranchuk před 3 roky
rodič
revize
c23034a468
1 změnil soubory, kde provedl 11 přidání a 9 odebrání
  1. 11 9
      src/client/sequential_autograd.py

+ 11 - 9
src/client/sequential_autograd.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import serialize_torch_tensor
-from hivemind.moe.client.expert import expert_backward, expert_forward
+from hivemind.moe.client.expert import expert_backward, expert_forward, _forward_stream
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
@@ -39,14 +39,14 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
-    # TODO: figure out whether we should use run_in_executor here
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, proto.compression)
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(*(
+        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
         for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
-    )
+    ))
+
     deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
     flat_outputs = tuple(deserialized_outputs)
-
     return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -67,10 +67,12 @@ async def run_expert_backward(
     inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
 
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, proto.compression)
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(*(
+        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
         for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-    )
+    ))
+    
     deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
     return deserialized_grad_inputs