
switch to local code

justheuristic 2 years ago
parent
commit
964dc32c3f
2 files changed, 90 insertions(+), 9 deletions(-)
  1. src/client/remote_forward_backward.py (+78 -0)
  2. src/client/sequential_autograd.py (+12 -9)

+ 78 - 0
src/client/remote_forward_backward.py

@@ -0,0 +1,78 @@
+"""
+Utility functions that call RPC forward or backward on a single remote server
+"""
+from typing import Iterable, List, Sequence
+
+import torch
+from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.streaming import split_for_streaming
+
+
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def remote_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    """Call rpc_backward (unary or stream) on a single remote server, return grads w.r.t. arguments"""
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+        if size > MAX_UNARY_PAYLOAD_SIZE:
+            # payload is too large for a single unary RPC: stream it in chunks
+            return await _backward_stream(uid, serialized_tensors, stub)
+    return await _backward_unary(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def remote_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    """Call rpc_forward (unary or stream) on a single remote server, return block outputs"""
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+        if size > MAX_UNARY_PAYLOAD_SIZE:
+            # payload is too large for a single unary RPC: stream it in chunks
+            return await _forward_stream(uid, serialized_tensors, stub)
+    return await _forward_unary(uid, serialized_tensors, stub)
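
For context, a minimal sketch (not part of this commit) of how these helpers are meant to be driven from synchronous client code. It skips the rpc_info schema handling and compression that run_remote_forward performs; the function name and tensor shapes are illustrative, while the stub/uid plumbing follows sequential_autograd.py.

import torch
from hivemind import serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker

from src.client.remote_forward_backward import remote_forward
from src.server.handler import TransformerConnectionHandler


def forward_through_one_server(sequence_manager, span, hidden_states: torch.Tensor):
    """Push hidden_states through one server's blocks (illustrative sketch only)."""
    uid = sequence_manager.block_uids[span.start]  # a single-block UID, for simplicity
    stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
    serialized = [serialize_torch_tensor(hidden_states)]
    # remote_forward compares the raw payload size against MAX_UNARY_PAYLOAD_SIZE and
    # dispatches to rpc_forward (unary) or rpc_forward_stream (chunked) accordingly
    return RemoteExpertWorker.run_coroutine(remote_forward(uid, (hidden_states,), serialized, stub))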

+ 12 - 9
src/client/sequential_autograd.py

@@ -1,14 +1,17 @@
+"""
+A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
+"""
 import asyncio
 import logging
 from typing import List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import serialize_torch_tensor
-from hivemind.moe.client.expert import expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 
+from src.client.remote_forward_backward import remote_backward, remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
@@ -17,11 +20,11 @@ from src.utils.misc import DUMMY, is_dummy
 MAX_TOKENS_IN_BATCH = 1024
 
 
-async def run_expert_forward(
+async def run_remote_forward(
     uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
-    Serializes input tensors and calls "expert_forward".
+    Serializes input tensors and calls "rpc_forward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
@@ -55,12 +58,12 @@ async def run_expert_forward(
         )
     )
 
-    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
+    deserialized_outputs = await remote_forward(uid, inputs, serialized_tensors, stub)
     flat_outputs = tuple(deserialized_outputs)
     return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def run_expert_backward(
+async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
@@ -69,7 +72,7 @@ async def run_expert_backward(
     *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
-    Serializes grad outputs and calls "expert_backward".
+    Serializes grad outputs and calls "rpc_backward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
@@ -93,7 +96,7 @@ async def run_expert_backward(
         )
     )
 
-    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
+    deserialized_grad_inputs = await remote_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
     return deserialized_grad_inputs
 
 
@@ -130,7 +133,7 @@ async def sequential_forward(
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
-                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
+                (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -171,7 +174,7 @@ async def sequential_backward(
             span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                grad_outputs, *span_grad_prompts = await run_expert_backward(
+                grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                 )
                 grad_outputs = [grad_outputs]
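
The last two hunks sit inside the retry loops of sequential_forward and sequential_backward. Below is a rough, heavily simplified sketch of that fault-tolerance pattern; it is not the code from this commit, and make_sequence(start_index=...) is an assumption about the sequence manager's API. When any server in the chain fails, the exception is caught, the remaining blocks are re-planned through the sequence manager, and the pass resumes from the last block that succeeded.

import logging

import torch

from src.client.sequential_autograd import run_remote_forward
from src.data_structures import CHAIN_DELIMITER
from src.server.handler import TransformerConnectionHandler


async def sequential_forward_sketch(inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager) -> torch.Tensor:
    """Simplified fault-tolerant forward pass over a chain of remote servers (illustration only)."""
    outputs, block_idx = inputs, 0
    while block_idx < len(sequence_manager.block_uids):
        # re-plan the remaining blocks into server spans (make_sequence signature assumed)
        for span in sequence_manager.make_sequence(start_index=block_idx):
            span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
            stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
            try:
                (outputs,) = await run_remote_forward(
                    span_uids, stub, sequence_manager.rpc_info, outputs, prompts[span.start : span.end]
                )
                block_idx = span.end  # this span succeeded; advance past its blocks
            except Exception:
                logging.exception(f"Server {span.peer_id} failed, re-planning from block {block_idx}")
                break  # leave the for-loop so the while-loop re-plans the remaining blocks
    return outputs

Note that this sketch retries indefinitely; a real client would eventually want a retry budget or backoff.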