Browse Source

pre-check type

justheuristic 3 years ago
Parent
Current commit
62d7fde8af
2 changed files with 33 additions and 4 deletions
  1. src/client/inference_chain.py: 29 additions, 0 deletions
  2. src/server/handler.py: 4 additions, 4 deletions

+ 29 - 0
src/client/inference_chain.py

@@ -0,0 +1,29 @@
+from typing import Sequence
+from collections import defaultdict
+
+import torch
+from hivemind import DHT
+from torch import nn
+
+from src import DistributedBloomConfig
+
+MAX_LENGTH = 128  # TODO: un-hardcode
+
+
+class RemoteInferenceChain(nn.Module):
+    """An auxiliary class that manages distributed inference in a chain of one or more remote transformer modules"""
+
+    def __init__(self, dht: DHT, config: DistributedBloomConfig, block_names: Sequence[str]):
+        super().__init__()
+        self.dht = dht
+        self.config, self.block_names = config, block_names
+        self.block_caches = {name: torch.zeros(1, MAX_LENGTH, config.hidden_size) for name in block_names}
+        self.current_position = 0
+
+    def step(self, hidden_states: torch.Tensor):
+        pass  # stub: incremental forward through the remote blocks, see plan below
+
+# plan:
+# - run inference STUB from a Jupyter notebook
+# - extend to run actual inference
+# - extend to run multiple layers at a time

+ 4 - 4
src/server/handler.py

@@ -1,7 +1,7 @@
 from typing import AsyncIterator, Dict
 
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, ModuleBackend
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
@@ -13,20 +13,20 @@ class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
     def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
-        super().__init__(dht, module_backends)
         for module_backend in module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        super().__init__(dht, module_backends)
 
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
 
         request = await anext(requests)
-        backend = self.experts[request.uid]
+        backend = self.module_backends[request.uid]
         assert isinstance(backend, TransformerBackend)
 
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
+        async with backend.memory_cache.allocate_cache(TensorDescriptor(size=(1,2,3), dtype=torch.float32)):
             outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
 
         yield runtime_pb2.ExpertResponse(tensors=outputs)
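The `TensorDescriptor(size=(1, 2, 3), dtype=torch.float32)` above is an explicitly typed placeholder. A real inference cache would presumably be sized from the model; a minimal sketch, assuming the backend's module exposes a config with `hidden_size` (an attribute path not confirmed by this commit):

    # Hypothetical cache sizing (assumed attributes, not in this commit):
    # one slot per position for a single-sequence session.
    max_length = 128  # mirrors the MAX_LENGTH placeholder in inference_chain.py
    descriptor = TensorDescriptor(size=(1, max_length, backend.module.config.hidden_size),
                                  dtype=torch.float32)
    async with backend.memory_cache.allocate_cache(descriptor):
        outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)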