
inference mode

justheuristic 3 years ago
parent commit e3a7d5af30
2 changed files with 15 additions and 18 deletions
  1. src/server/backend.py (+9 −4)
  2. src/server/handler.py (+6 −14)

src/server/backend.py (+9 −4)

@@ -1,8 +1,7 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Tuple
+from typing import Tuple, Sequence
 
 import torch
-from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.task_pool import TaskPool
 
@@ -14,7 +13,7 @@ class TransformerBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
-        super().__init__(*args, **kwargs)  # to bypass super.__init__
+        super().__init__(*args, **kwargs)
         self.memory_cache = memory_cache
 
         for name, param in self.module.named_parameters():
@@ -22,6 +21,12 @@ class TransformerBlockBackend(ExpertBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
+
+        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+
+    def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             return inputs[0] * 2
+
+    def get_pools(self) -> Sequence[TaskPool]:
+        return self.forward_pool, self.backward_pool, self.inference_pool
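
The get_pools override is what exposes the new pool to the serving loop: hivemind's Runtime collects the TaskPools returned by each backend's get_pools() and schedules them together. A minimal wiring sketch, assuming a fully constructed TransformerBlockBackend named `backend` (illustration only, not part of this commit):

    from hivemind.moe.server.runtime import Runtime

    # Runtime iterates over backend.get_pools(), so forward_pool, backward_pool,
    # and the new inference_pool are all served from the same loop.
    runtime = Runtime({backend.name: backend})
    runtime.run()  # blocks and processes batches from all three pools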

src/server/handler.py (+6 −14)

@@ -15,24 +15,16 @@ class TransformerConnectionHandler(ConnectionHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    async def rpc_forward_incremental(
+    async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
 
         request = await anext(requests)
-        expert = self.experts[request.uid]
-        assert isinstance(expert, TransformerBlockBackend)
+        backend = self.experts[request.uid]
+        assert isinstance(backend, TransformerBlockBackend)
 
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
-            outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+        async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
+            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
 
-        return runtime_pb2.ExpertResponse(tensors=outputs)
-
-
-        # note: you may use self.experts[uid].memory_cache!
-        # encode expert_uid as @model_name[starting_layer:finishing_layer]
-        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
-        # - receive and maintain a handle for attention cache here
-
-        raise NotImplementedError()
+        yield runtime_pb2.ExpertResponse(tensors=outputs)
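
Because rpc_inference is now an async generator that yields an ExpertResponse instead of returning one, callers consume it as a stream. A hedged driver sketch; `handler`, `uid`, `context`, and the serialize/deserialize import path are assumptions, not part of this diff:

    # import path may differ across hivemind versions
    from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
    from hivemind.proto import runtime_pb2

    async def run_single_inference_step(handler, uid, hidden_states, context):
        # Wrap one tensor into an ExpertRequest and expose it as a one-item request stream.
        request = runtime_pb2.ExpertRequest(uid=uid, tensors=[serialize_torch_tensor(hidden_states)])

        async def requests():
            yield request

        # rpc_inference reads the first request, pushes it through inference_pool,
        # and streams back a single ExpertResponse.
        async for response in handler.rpc_inference(requests(), context):
            return [deserialize_torch_tensor(t) for t in response.tensors]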