justheuristic 2 年 前
コミット
208c09ef25
2 ファイル変更17 行追加8 行削除
  1. src/client/remote_forward_backward.py (+8 −3)
  2. src/client/sequential_autograd.py (+9 −5)

+ 8 - 3
src/client/remote_forward_backward.py

@@ -2,7 +2,7 @@
 Utility functions that call RPC forward or backward on a single remote server
 """
 import asyncio
-from typing import Iterable, List, Sequence, Tuple, Optional
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
@@ -63,8 +63,13 @@ async def _backward_stream(
 
 
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float,
-        metadata: Optional[bytes] = None, **kwargs
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    *inputs: torch.Tensor,
+    timeout: float,
+    metadata: Optional[bytes] = None,
+    **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.

+ 9 - 5
src/client/sequential_autograd.py

@@ -72,11 +72,14 @@ async def sequential_forward(
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                metadata = sequence_manager.get_request_metadata(
-                    'rpc_forward', span_uids, *inputs_and_prompts)
+                metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
                 (outputs,) = await run_remote_forward(
-                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts,
-                    timeout=sequence_manager.timeout, metadata=metadata
+                    span_uids,
+                    stub,
+                    sequence_manager.rpc_info,
+                    *inputs_and_prompts,
+                    timeout=sequence_manager.timeout,
+                    metadata=metadata,
                 )
 
                 assert isinstance(outputs, torch.Tensor)
@@ -150,7 +153,8 @@ async def sequential_backward(
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
-                    'rpc_backward', span_uids, *inputs, *grad_outputs, peer_id=span.peer_id)
+                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
+                )
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
                     stub,