justheuristic 3 years ago
parent
commit
628e168897

+ 2 - 2
src/client/remote_block.py

@@ -6,14 +6,14 @@ import random
 from typing import AsyncIterator, Optional
 
 import torch
-from hivemind import serialize_torch_tensor, nested_flatten, deserialize_torch_tensor, anext
+from hivemind import anext, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
 from hivemind.utils import get_logger, use_hivemind_log_handler
 
-from src.data_structures import RemoteModuleInfo, ModuleUID, RPCInfo
+from src.data_structures import ModuleUID, RemoteModuleInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")

+ 9 - 7
src/client/remote_sequential.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import contextlib
 import logging
 import random
-from typing import Optional, Union, List
+from typing import List, Optional, Union
 
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
@@ -14,7 +14,7 @@ from torch import nn
 import src
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import UID_DELIMITER, RemoteSpanInfo, CHAIN_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, RemoteSpanInfo
 from src.dht_utils import _create_remote_modules_from_infos
 from src.server.handler import TransformerConnectionHandler
 
@@ -108,7 +108,7 @@ class RemoteSequential(nn.Module):
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float]=None):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
@@ -125,10 +125,12 @@ class RemoteSequentialInferenceSession:
 
         for chosen_span in self.chosen_spans:
             stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
-            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start: chosen_span.end])
-            inference_session = RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(
-                stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
-            ))
+            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
+            inference_session = RemoteExpertWorker.run_coroutine(
+                RemoteTransformerBlockInferenceSession._create(
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+                )
+            )
             self.inference_sessions.append(inference_session)
             self.stack.enter_context(inference_session)
 

+ 1 - 1
src/client/sequence_manager.py

@@ -4,7 +4,7 @@ import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
 
-from hivemind import DHT, DHTExpiration, P2P, MSGPackSerializer
+from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler