
extract remote block inference

justheuristic 3 years ago
parent commit 6ed9ee2d3a

+ 2 - 1
src/client/__init__.py

@@ -1,4 +1,5 @@
-from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
+from src.client.remote_block import RemoteTransformerBlock
+from src.client.remote_block_inference import RemoteTransformerBlockInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential
 from src.client.sequence_manager import RemoteSequenceManager

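Since RemoteTransformerBlockInferenceSession is still re-exported from src.client (the added import above), both import paths below should resolve to the same class. A quick sanity check, assuming nothing else shadows the name:

    from src.client import RemoteTransformerBlockInferenceSession  # via the re-export
    from src.client.remote_block_inference import (
        RemoteTransformerBlockInferenceSession as DirectSession,  # direct module path
    )

    assert RemoteTransformerBlockInferenceSession is DirectSession
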
+ 3 - 93
src/client/remote_block.py

@@ -1,20 +1,16 @@
 # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
 from __future__ import annotations
 
-import asyncio
 import random
-from typing import Any, AsyncIterator, Dict, Optional
 
 import torch
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
+from hivemind.utils import get_logger, use_hivemind_log_handler
 
-from src.data_structures import RemoteModuleInfo, RPCInfo
-from src.dht_utils import ModuleUID
+from src import RemoteTransformerBlockInferenceSession
+from src.data_structures import RemoteModuleInfo
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -46,89 +42,3 @@ class RemoteTransformerBlock(RemoteExpert):
     def begin_inference_session(self):
         logger.warning("beging_inference_session was renamed to just inference_session")
         return self.inference_session()
-
-
-class RemoteTransformerBlockInferenceSession:
-    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
-
-    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
-        self.uid, self.rpc_info = uid, rpc_info
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side effects including headaches, diarrhea, and loss of sleep
-        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.stepped = False
-        self.closed = False
-
-    @classmethod
-    async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
-    ) -> RemoteTransformerBlockInferenceSession:
-        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream)
-
-    @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
-        while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
-            yield next_input_message
-            if not next_input_message.uid and not next_input_message.tensors:
-                break  # this message means "done sending"
-
-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
-        if self.closed:
-            raise Exception("Session is closed, cannot perform step")
-        # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
-        outputs_serialized = RemoteExpertWorker.run_coroutine(
-            self._step(
-                runtime_pb2.ExpertRequest(
-                    uid=self.uid,
-                    tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
-                    ],
-                )
-            )
-        )
-        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs[0]
-
-    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
-        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
-        await self._inputs_queue.put(inputs_serialized)
-        self.stepped = True
-        return await anext(self._outputs_stream)
-
-    def close(self):
-        """Finish a given inference session, close the underlying connection"""
-        if self._outputs_stream is None:
-            return  # already closed
-        RemoteExpertWorker.run_coroutine(self._aclose_stream())
-        self._outputs_stream = self._inputs_queue = None
-        self.closed = True
-
-    async def _aclose_stream(self):
-        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
-        if self._outputs_stream is None:
-            return  # already closed
-        if self.stepped:
-            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-            try:
-                await anext(self._outputs_stream)
-            except StopAsyncIteration:
-                pass
-
-    def __del__(self):
-        self.close()
-
-    def __enter__(self):
-        assert not self.closed
-        return self
-
-    def __exit__(self, *exc_details):
-        self.close()

+ 98 - 0
src/client/remote_block_inference.py

@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+import asyncio
+from typing import AsyncIterator, Optional
+
+import torch
+from hivemind import serialize_torch_tensor, nested_flatten, deserialize_torch_tensor, anext
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.proto import runtime_pb2
+
+from src.data_structures import ModuleUID, RPCInfo
+
+
+class RemoteTransformerBlockInferenceSession:
+    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
+
+    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+        self.uid, self.rpc_info = uid, rpc_info
+        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
+        # using them in any other EventLoop may cause side effects including headaches, diarrhea, and loss of sleep
+        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.stepped = False
+        self.closed = False
+
+    @classmethod
+    async def _create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
+    ) -> RemoteTransformerBlockInferenceSession:
+        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        inputs_queue = asyncio.Queue()
+        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream)
+
+    @staticmethod
+    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+        while True:
+            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            yield next_input_message
+            if not next_input_message.uid and not next_input_message.tensors:
+                break  # this message means "done sending"
+
+    def step(self, new_hidden_states: torch.Tensor):
+        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+        if self.closed:
+            raise Exception("Session is closed, cannot perform step")
+        # serialize inputs and put them into the queue
+        inputs = (new_hidden_states,)
+        outputs_serialized = RemoteExpertWorker.run_coroutine(
+            self._step(
+                runtime_pb2.ExpertRequest(
+                    uid=self.uid,
+                    tensors=[
+                        serialize_torch_tensor(tensor, proto.compression)
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                    ],
+                )
+            )
+        )
+        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        return outputs[0]
+
+    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+        await self._inputs_queue.put(inputs_serialized)
+        self.stepped = True
+        return await anext(self._outputs_stream)
+
+    def close(self):
+        """Finish a given inference session, close the underlying connection"""
+        if self._outputs_stream is None:
+            return  # already closed
+        RemoteExpertWorker.run_coroutine(self._aclose_stream())
+        self._outputs_stream = self._inputs_queue = None
+        self.closed = True
+
+    async def _aclose_stream(self):
+        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+        if self._outputs_stream is None:
+            return  # already closed
+        if self.stepped:
+            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+            try:
+                await anext(self._outputs_stream)
+            except StopAsyncIteration:
+                pass
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        assert not self.closed
+        return self
+
+    def __exit__(self, *exc_details):
+        self.close()

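For context, a hedged sketch of how one of these sessions is driven once a block is set up; `block` (an already-initialized RemoteTransformerBlock) and `hidden_size` are assumptions for illustration, not part of this commit:

    import torch

    hidden_size = 14336  # hypothetical; depends on the model config

    with block.inference_session() as session:  # RemoteTransformerBlockInferenceSession
        for _ in range(3):  # one call per generated token
            new_hidden_states = torch.randn(1, 1, hidden_size)
            outputs = session.step(new_hidden_states)
    # __exit__ -> close() enqueues an empty ExpertRequest so the server ends the stream
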
+ 13 - 9
src/client/remote_sequential.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import contextlib
 import logging
 import random
-from typing import Optional, Union
+from typing import Optional, Union, List
 
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
@@ -13,8 +13,9 @@ from torch import nn
 
 import src
 from src.client.remote_block import RemoteTransformerBlock
+from src import RemoteTransformerBlockInferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import UID_DELIMITER
+from src.data_structures import UID_DELIMITER, RemoteSpanInfo
 from src.dht_utils import _create_remote_modules_from_infos
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -113,23 +114,26 @@ class RemoteSequentialInferenceSession:
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
+        self.chosen_spans: List[RemoteSpanInfo] = []
         self.stack = contextlib.ExitStack()
-        self.active_sessions = []
+        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
 
     def __enter__(self):
-        assert not self.closed
+        assert not self.closed and not self.chosen_spans
         self.stack.__enter__()
         # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
+        self.chosen_spans.extend(self.sequence_manager.make_sequence())
 
-        for chosen_span in self.sequence_manager.make_sequence():
+        current_block = 0
+        for chosen_span in self.chosen_spans:
+            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
 
             # TODO begin throwaway prototype code
             remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)
             _ = remote.info  # TODO fix
             span_uids = self.sequence_manager.block_uids[current_block : chosen_span.end]
             remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
-            self.active_sessions.append(remote.inference_session())
-            self.stack.enter_context(self.active_sessions[-1])
+            self.inference_sessions.append(remote.inference_session())
+            self.stack.enter_context(self.inference_sessions[-1])
             current_block = chosen_span.end
             # TODO end throwaway prototype code
 
@@ -137,7 +141,7 @@ class RemoteSequentialInferenceSession:
 
     def step(self, inputs: torch.Tensor):
         assert not self.closed
-        for session in self.active_sessions:
+        for session in self.inference_sessions:
             outputs = session.step(inputs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
@@ -147,7 +151,7 @@ class RemoteSequentialInferenceSession:
         """Finish a given inference session, close the underlying connection"""
         if not self.closed:
             self.stack.__exit__(*exc_details or (None, None, None))
-            self.active_sessions.clear()
+            self.inference_sessions.clear()
             self.closed = True
 
     def __exit__(self, *exc_details):
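
Putting the pieces together, a hedged sketch of a multi-span session as restructured above, assuming the constructor takes (sequence_manager, p2p) as the assignments in __init__ suggest; `sequence_manager`, `p2p`, and `hidden_size` must already exist and are not defined in this diff:

    import torch

    session = RemoteSequentialInferenceSession(sequence_manager, p2p)
    with session:  # __enter__ picks chosen_spans and opens one per-span inference session
        for _ in range(3):  # one call per generated token
            new_hidden_states = torch.randn(1, 1, hidden_size)
            outputs = session.step(new_hidden_states)  # flows through every span in order
    # __exit__ closes each per-span session via the ExitStack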