Преглед изворни кода

Refactor _{forward,backward}_stream()

Aleksandr Borzunov пре 2 година
родитељ
комит
cd829fde92
1 измењених фајлова са 43 додато и 50 уклоњено
  1. 43 50
      src/client/remote_forward_backward.py

+ 43 - 50
src/client/remote_forward_backward.py

@@ -10,12 +10,54 @@ from hivemind.compression.serialization import deserialize_tensor_stream, deseri
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.asyncio import iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import ModuleUID, RPCInfo
 
 
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    outputs = await stub.rpc_forward_stream(iter_as_aiter(parts))
+    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    grad_inputs = await stub.rpc_backward_stream(iter_as_aiter(parts))
+    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
+
+
 async def run_remote_forward(
     uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
 ) -> Tuple[torch.Tensor, ...]:
@@ -64,46 +106,6 @@ async def run_remote_forward(
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
-
-    outputs = await stub.rpc_forward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
-async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
-
-async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
-
-    grad_inputs = await stub.rpc_backward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
 async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
@@ -145,12 +147,3 @@ async def run_remote_backward(
         deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
 
     return deserialized_grad_inputs
-
-
-async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]