black-isort

justheuristic 3 years ago
Parent
Commit
3b16d6ffdb
3 changed files with 21 additions and 14 deletions
  1. src/server/backend.py (+2 -9)
  2. src/server/handler.py (+16 -2)
  3. src/server/server.py (+3 -3)

+ 2 - 9
src/server/backend.py

@@ -10,14 +10,7 @@ from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 
 
-# TODO
-# BloomBackend serves a single layer
-# - ensure that parameters do not require grad!
-# - ensure that TaskPool for inference is NOT batched
-# - ensure that optimizer/scheduler is not created
-
-
-class BloomBlockBackend(ExpertBackend):
+class TransformerBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
@@ -31,4 +24,4 @@ class BloomBlockBackend(ExpertBackend):
 
     def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-            raise NotImplementedError("TODO")
+            return inputs[0] * 2
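Note that the new `forward_incremental` body is still a placeholder (it just returns `inputs[0] * 2`). A real incremental step would append the newest key/value pair into the buffers yielded by `use_cache` and attend over the stored prefix. A minimal, self-contained sketch of that cache update, using plain torch with assumed single-head shapes (none of this is part of the commit):

import torch

def incremental_attention_step(
    query: torch.Tensor,          # [batch, 1, head_dim] - projection of the newest token
    new_key: torch.Tensor,        # [batch, 1, head_dim]
    new_value: torch.Tensor,      # [batch, 1, head_dim]
    current_length: int,          # number of tokens already stored in the cache
    cached_keys: torch.Tensor,    # [batch, max_length, head_dim], preallocated by MemoryCache
    cached_values: torch.Tensor,  # [batch, max_length, head_dim]
) -> torch.Tensor:
    # Write the newest key/value into the next free cache slot (in place).
    cached_keys[:, current_length] = new_key.squeeze(1)
    cached_values[:, current_length] = new_value.squeeze(1)
    keys = cached_keys[:, : current_length + 1]
    values = cached_values[:, : current_length + 1]
    # Attend the single new query over every cached position.
    scores = torch.einsum("bqd,bkd->bqk", query, keys) / keys.shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("bqk,bkd->bqd", weights, values)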

+ 16 - 2
src/server/handler.py

@@ -1,10 +1,12 @@
 from typing import AsyncIterator, Dict
 
-from hivemind import P2PContext, DHT
+import torch
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import anext
 
-from src.bloom.block import BloomBlock
+from src.server.backend import TransformerBlockBackend
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -16,6 +18,18 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_forward_incremental(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+
+        request = await anext(requests)
+        expert = self.experts[request.uid]
+        assert isinstance(expert, TransformerBlockBackend)
+
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
+            outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+
+        return runtime_pb2.ExpertResponse(tensors=outputs)
+
+
         # note: you may use self.experts[uid].memory_cache!
         # encode expert_uid as @model_name[starting_layer:finishing_layer]
         # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
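The new handler body processes exactly one request and returns, while the trailing comments describe a long-lived stream ("while not closed: ... repeat"). A rough sketch of that loop, reusing only names that already appear in this commit (`anext`, `self.experts`, `allocate_cache`, `_process_inputs`) and keeping the dummy `torch.randn(3)` cache descriptor as a placeholder; this is an assumption about where the code is headed, not part of the change:

    # Hypothetical streaming version of the method above; shape checks and the
    # real cache descriptor are still TODO.
    async def rpc_forward_incremental(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Read the first request to learn which expert (layer) the client wants.
        request = await anext(requests)
        expert = self.experts[request.uid]
        assert isinstance(expert, TransformerBlockBackend)

        # Allocate the attention cache once per stream; it is released when the client disconnects.
        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
            while True:
                inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
                outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
                yield runtime_pb2.ExpertResponse(tensors=outputs)
                try:
                    request = await anext(requests)  # next token(s) from the same client
                except StopAsyncIteration:
                    break  # client closed the stream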

+ 3 - 3
src/server/server.py

@@ -14,7 +14,7 @@ import multiprocessing as mp
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import BloomBlockBackend
+from src.server.backend import TransformerBlockBackend
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -27,7 +27,7 @@ class Server(threading.Thread):
     def __init__(
         self,
         dht: DHT,
-        module_backends: Dict[str, BloomBlockBackend],
+        module_backends: Dict[str, TransformerBlockBackend],
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
@@ -118,7 +118,7 @@ class Server(threading.Thread):
             for param in block.parameters():
                 param.requires_grad = False
 
-            blocks[module_uid] = BloomBlockBackend(
+            blocks[module_uid] = TransformerBlockBackend(
                 module_uid,
                 block,
                 memory_cache=memory_cache,