basic multi-step inference session

justheuristic committed 3 years ago
commit a00ec56ade
4 changed files with 135 additions and 41 deletions
  1. src/client/remote_block.py  (+ 116 - 27)
  2. src/server/backend.py  (+ 1 - 1)
  3. src/server/cache.py  (+ 5 - 5)
  4. src/server/handler.py  (+ 13 - 8)

+ 116 - 27
src/client/remote_block.py

@@ -1,33 +1,109 @@
-from concurrent.futures import Future
+from __future__ import annotations
+import asyncio
 from functools import partial
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Sequence, AsyncIterator, Dict, Any
 
 
 import torch
-from hivemind.moe.client import RemoteExpert
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertUID
-from hivemind.moe.server.dht_handler import _get_experts
-from hivemind.p2p import StubBase, P2P
-from hivemind.proto.runtime_pb2 import ExpertInfo
-from hivemind.dht import DHT
-from hivemind.utils import MPFuture, DHTExpiration
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertUID, ExpertInfo as RemoteModuleInfo
+from hivemind.p2p import P2P, PeerID, StubBase
+from hivemind.proto import runtime_pb2
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.utils import MPFuture, DHTExpiration, get_dht_time, as_aiter, anext, nested_flatten
+from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
 
 
 from src.server.handler import TransformerConnectionHandler
 
 
 
 
 class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a specific remote server for forward/backward or inference"""
-
-    def __init__(self, info: ExpertInfo, p2p: P2P):
-        super().__init__(info, p2p)
-        # self._config = config
-        # self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
-        # self._active_stream: Optional[RemoteTransformerStream] = None
+    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
 
     @property
     def stub(self) -> StubBase:
         return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
 
 
+    def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
+        """Initialize a new inference session with the specified remote server"""
+        return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
+
+
+class RemoteTransformerBlockInferenceSession:
+    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
+    def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+        self.uid, self.info = uid, info
+        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
+        # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
+        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.closed = False
+
+    @classmethod
+    async def _create(
+            cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
+    ) -> RemoteTransformerBlockInferenceSession:
+        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        inputs_queue = asyncio.Queue()
+        outputs_stream = await remote_module.stub.rpc_inference(
+            cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
+        )
+        return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
+
+    @staticmethod
+    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+        while True:
+            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            yield next_input_message
+            if not next_input_message.uid and not next_input_message.tensors:
+                break  # this message means "done sending"
+
+    def step(self, new_hidden_states: torch.Tensor):
+        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+        if self.closed:
+            raise Exception("Session is closed, cannot perform step")
+        # serialize inputs and put them into the queue
+        inputs = (new_hidden_states,)
+        outputs_serialized = RemoteExpertWorker.run_coroutine(self._step(
+            runtime_pb2.ExpertRequest(uid=self.uid, tensors=[
+                serialize_torch_tensor(tensor, proto.compression)
+                for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
+            ])
+        ))
+        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        return outputs[0]
+
+    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+        await self._inputs_queue.put(inputs_serialized)
+        return await anext(self._outputs_stream)
+
+    def close(self):
+        """Finish a given inference session, close the underlying connection"""
+        if self._outputs_stream is None:
+            return  # already closed
+        RemoteExpertWorker.run_coroutine(self._aclose_stream())
+        self._outputs_stream = self._inputs_queue = None
+        self.closed = True
+
+    async def _aclose_stream(self):
+        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+        if self._outputs_stream is None:
+            return  # already closed
+        await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+        try:
+            await anext(self._outputs_stream)
+        except StopAsyncIteration:
+            pass
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        assert not self.closed
+        return self
+
+    def __exit__(self, *exc_details):
+        self.close()
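
A minimal client-side usage sketch of the session API added above (the initial peers, module uid, argument order of get_remote_module, and hidden size are illustrative assumptions, not taken from this commit):

    import torch
    from hivemind import DHT
    from src.client.remote_block import get_remote_module

    initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/QmServerPeerIdGoesHere"]  # placeholder address
    dht = DHT(initial_peers=initial_peers, start=True)
    (block,) = get_remote_module(dht, ["bloom.transformer.0"])  # hypothetical uid; None if not found in the DHT
    with block.begin_inference_session() as session:  # __exit__ closes the underlying stream
        for _ in range(3):  # three single-token autoregressive steps
            hidden_states = torch.randn(1, 1, 4096)  # (batch, new_tokens, hidden_size) - assumed dimensions
            hidden_states = session.step(hidden_states)

Each step() call pushes one ExpertRequest into the background worker's queue and blocks until the matching ExpertResponse arrives, so the server-side attention cache advances by new_tokens per call.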
 
 
 
 
 def get_remote_module(
@@ -40,25 +116,38 @@ def get_remote_module(
     :returns: a list of [RemoteTransformerBlock if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-    return create_remote_module(result, dht, return_future)
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time),
+        return_future)
 
 
-
-def create_remote_module(
-    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
     if return_future:
-
         async def _unpack(infos_future: MPFuture, dht: DHT):
             p2p = await dht.replicate_p2p()
-            return _create_remote_experts(await infos_future, p2p)
+            return _create_remote_modules_from_infos(await infos_future, p2p)
 
 
         return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
     p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    return _create_remote_experts(infos, p2p)
+    return _create_remote_modules_from_infos(infos, p2p)
+
+
+async def _get_remote_module_infos(
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
+) -> List[Optional[RemoteModuleInfo]]:
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    experts: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        server_peer_id = found[uid]
+        if server_peer_id is not None and isinstance(server_peer_id.value, str):
+            experts[i] = RemoteModuleInfo(uid, PeerID.from_base58(server_peer_id.value))
+    return experts
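
The lookup above expects every served module uid to map to the serving peer's base58 PeerID in the DHT. A hedged sketch of the matching server-side announcement (dht.store and dht.peer_id come from the standard hivemind DHT API; their use here is an assumption, not part of this diff):

    from hivemind import DHT
    from hivemind.utils import get_dht_time

    def announce_module(dht: DHT, uid: str, ttl: float = 30.0) -> bool:
        # store uid -> this peer's base58 id so that _get_remote_module_infos can resolve it
        return dht.store(uid, dht.peer_id.to_base58(), expiration_time=get_dht_time() + ttl)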
 
 
 
 
-def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
+def _create_remote_modules_from_infos(infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
+                                      ) -> List[Optional[RemoteTransformerBlock]]:
     experts: List[Optional[RemoteTransformerBlock]] = []
     for info in infos:
         if info is not None:

+ 1 - 1
src/server/backend.py

@@ -30,7 +30,7 @@ class TransformerBackend(ModuleBackend):
         attention_cache_handle = int(cache_metadata[0, 0].item())
         current_sequence_length = int(cache_metadata[0, 1].item())
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
-            print('METADATA:', cache_metadata, "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
+            print('METADATA:', cache_metadata, "CACHE", cache.mean(), "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
             cache[...] += 1
             return (inputs[0] + cache.flatten()[0],)
 
 

+ 5 - 5
src/server/cache.py

@@ -4,11 +4,11 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 
 TODO In future, one could modify cache to implement, among other things,
-- in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
--- note: this can be done using mp.Condtion
-- allocate cache as one contigous buffer to avoid fragmentation
-- quantize cached values using bitsandbytes
-- LRU offloading from gpu to ram
+-
+--
+-
+-
+-
 
 
 """
 """
 import contextlib
 import contextlib

+ 13 - 8
src/server/handler.py

@@ -22,7 +22,10 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
+            print("OPENED RPC_INFERENCE")
             request = await anext(requests)
+            if not request.uid:
+                raise RuntimeError("User did not provide any uids.")
             backend = self.module_backends[request.uid]
             assert isinstance(backend, TransformerBackend)
 
 
@@ -33,13 +36,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             current_sequence_length = 0
 
 
             async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
-                inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
-                print("INPUTS:", inputs)
-                assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
-                cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
-                outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
-                yield runtime_pb2.ExpertResponse(tensors=outputs)
-
-                current_sequence_length += inputs[1].shape[1]
+                while request.uid or request.tensors:  # iterate while user is willing to supply tensors
+                    inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
+                    print("INPUTS:", inputs)
+                    assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
+                    cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
+                    outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
+                    yield runtime_pb2.ExpertResponse(tensors=outputs)
+
+                    current_sequence_length += inputs[1].shape[1]
+                    request = await anext(requests)
         finally:
             print("CLOSED RPC_INFERENCE")