Explorar o código

basic multi-step inference session

justheuristic %!s(int64=3) %!d(string=hai) anos
pai
achega
a00ec56ade
Modificáronse 4 ficheiros con 135 adicións e 41 borrados
  1. 116 27
      src/client/remote_block.py
  2. 1 1
      src/server/backend.py
  3. 5 5
      src/server/cache.py
  4. 13 8
      src/server/handler.py

+ 116 - 27
src/client/remote_block.py

@@ -1,33 +1,109 @@
-from concurrent.futures import Future
+from __future__ import annotations
+import asyncio
 from functools import partial
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Sequence, AsyncIterator, Dict, Any
 
 import torch
-from hivemind.moe.client import RemoteExpert
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertUID
-from hivemind.moe.server.dht_handler import _get_experts
-from hivemind.p2p import StubBase, P2P
-from hivemind.proto.runtime_pb2 import ExpertInfo
-from hivemind.dht import DHT
-from hivemind.utils import MPFuture, DHTExpiration
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertUID, ExpertInfo as RemoteModuleInfo
+from hivemind.p2p import P2P, PeerID, StubBase
+from hivemind.proto import runtime_pb2
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.utils import MPFuture, DHTExpiration, get_dht_time, as_aiter, anext, nested_flatten
+from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
 
 from src.server.handler import TransformerConnectionHandler
 
 
 class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a specific remote server for forward/backward or inference"""
-
-    def __init__(self, info: ExpertInfo, p2p: P2P):
-        super().__init__(info, p2p)
-        # self._config = config
-        # self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
-        # self._active_stream: Optional[RemoteTransformerStream] = None
+    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
     @property
     def stub(self) -> StubBase:
         return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
 
+    def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
+        """Initialize a new inference session with the specified remote server"""
+        return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
+
+
+class RemoteTransformerBlockInferenceSession:
+    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
+    def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+        self.uid, self.info = uid, info
+        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
+        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
+        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.closed = False
+
+    @classmethod
+    async def _create(
+            cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
+    ) -> RemoteTransformerBlockInferenceSession:
+        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        inputs_queue = asyncio.Queue()
+        outputs_stream = await remote_module.stub.rpc_inference(
+            cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
+        )
+        return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
+
+    @staticmethod
+    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+        while True:
+            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            yield next_input_message
+            if not next_input_message.uid and not next_input_message.tensors:
+                break  # this message means "done sending"
+
+    def step(self, new_hidden_states: torch.Tensor):
+        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+        if self.closed:
+            raise Exception("Session is closed, cannot perform step")
+        # serialize inputs and put them into the queue
+        inputs = (new_hidden_states,)
+        outputs_serialized = RemoteExpertWorker.run_coroutine(self._step(
+            runtime_pb2.ExpertRequest(uid=self.uid, tensors=[
+                serialize_torch_tensor(tensor, proto.compression)
+                for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
+            ])
+        ))
+        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        return outputs[0]
+
+    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+        await self._inputs_queue.put(inputs_serialized)
+        return await anext(self._outputs_stream)
+
+    def close(self):
+        """Finish a given inference session, close the underlying connection"""
+        if self._outputs_stream is None:
+            return  # already closed
+        RemoteExpertWorker.run_coroutine(self._aclose_stream())
+        self._outputs_stream = self._inputs_queue = None
+        self.closed = True
+
+    async def _aclose_stream(self):
+        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+        if self._outputs_stream is None:
+            return  # already closed
+        await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+        try:
+            await anext(self._outputs_stream)
+        except StopAsyncIteration:
+            pass
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        assert not self.closed
+        return self
+
+    def __exit__(self, *exc_details):
+        self.close()
 
 
 def get_remote_module(
@@ -40,25 +116,38 @@ def get_remote_module(
     :returns: a list of [RemoteTransformerBlock if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-    return create_remote_module(result, dht, return_future)
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time),
+        return_future)
 
-
-def create_remote_module(
-    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
     if return_future:
-
         async def _unpack(infos_future: MPFuture, dht: DHT):
             p2p = await dht.replicate_p2p()
-            return _create_remote_experts(await infos_future, p2p)
+            return _create_remote_modules_from_infos(await infos_future, p2p)
 
         return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
     p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    return _create_remote_experts(infos, p2p)
+    return _create_remote_modules_from_infos(infos, p2p)
+
+
+async def _get_remote_module_infos(
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
+) -> List[Optional[RemoteModuleInfo]]:
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    experts: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        server_peer_id = found[uid]
+        if server_peer_id is not None and isinstance(server_peer_id.value, str):
+            experts[i] = RemoteModuleInfo(uid, PeerID.from_base58(server_peer_id.value))
+    return experts
 
 
-def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
+def _create_remote_modules_from_infos(infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
+                                      ) -> List[Optional[RemoteTransformerBlock]]:
     experts: List[Optional[RemoteTransformerBlock]] = []
     for info in infos:
         if info is not None:

+ 1 - 1
src/server/backend.py

@@ -30,7 +30,7 @@ class TransformerBackend(ModuleBackend):
         attention_cache_handle = int(cache_metadata[0, 0].item())
         current_sequence_length = int(cache_metadata[0, 1].item())
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
-            print('METADATA:', cache_metadata, "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
+            print('METADATA:', cache_metadata, "CACHE", cache.mean(), "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
             cache[...] += 1
             return (inputs[0] + cache.flatten()[0],)
 

+ 5 - 5
src/server/cache.py

@@ -4,11 +4,11 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 TODO In future, one could modify cache to implement, among other things,
-- in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
--- note: this can be done using mp.Condtion
-- allocate cache as one contigous buffer to avoid fragmentation
-- quantize cached values using bitsandbytes
-- LRU offloading from gpu to ram
+-
+--
+-
+-
+-
 
 """
 import contextlib

+ 13 - 8
src/server/handler.py

@@ -22,7 +22,10 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
+            print("OPENED RPC_INFERENCE")
             request = await anext(requests)
+            if not request.uid:
+                raise RuntimeError("User did not provide any uids.")
             backend = self.module_backends[request.uid]
             assert isinstance(backend, TransformerBackend)
 
@@ -33,13 +36,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             current_sequence_length = 0
 
             async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
-                inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
-                print("INPUTS:", inputs)
-                assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
-                cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
-                outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
-                yield runtime_pb2.ExpertResponse(tensors=outputs)
-
-                current_sequence_length += inputs[1].shape[1]
+                while request.uid or request.tensors:  # iterate while user is willing to supply tensors
+                    inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
+                    print("INPUTS:", inputs)
+                    assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
+                    cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
+                    outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
+                    yield runtime_pb2.ExpertResponse(tensors=outputs)
+
+                    current_sequence_length += inputs[1].shape[1]
+                    request = await(anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")