Forráskód Böngészése

disentangle inference session from RemoteTransformerBlock

justheuristic 3 éve
szülő
commit
6df55d6bd9
1 módosított fájl, 7 hozzáadás és 7 törlés
  1. 7 7
      src/client/remote_block.py

+ 7 - 7
src/client/remote_block.py

@@ -13,7 +13,7 @@ from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
 from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
 
-from src.data_structures import RemoteModuleInfo
+from src.data_structures import RemoteModuleInfo, RPCInfo
 from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
@@ -50,8 +50,8 @@ class RemoteTransformerBlock(RemoteExpert):
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
 
-    def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
-        self.uid, self.info = uid, info
+    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+        self.uid, self.rpc_info = uid, rpc_info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -61,14 +61,14 @@ class RemoteTransformerBlockInferenceSession:
 
     @classmethod
     async def _create(
-        cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await remote_module.stub.rpc_inference(
+        outputs_stream = await stub.rpc_inference(
             cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
         )
-        return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
@@ -90,7 +90,7 @@ class RemoteTransformerBlockInferenceSession:
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
                 )
             )