Переглянути джерело

switch to hivemind-master

justheuristic 3 роки тому
батько
коміт
5a15c13ca7
5 змінених файлів з 18 додано та 17 видалено
  1. 1 1
      README.md
  2. 2 2
      src/client/remote_block.py
  3. 5 6
      src/server/backend.py
  4. 7 5
      src/server/handler.py
  5. 3 3
      src/server/server.py

+ 1 - 1
README.md

@@ -18,7 +18,7 @@ conda activate bloom-demo
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 pip install bitsandbytes-cuda113==0.26.0
-pip install https://github.com/learning-at-home/hivemind/archive/bc2cccfdb0d7c905a12ef6c3ad052a1250af9878.zip
+pip install https://github.com/learning-at-home/hivemind/archive/master.zip
 pip install https://github.com/huggingface/transformers/archive/224bde91caff4ccfd12277ab5e9bf97c61e22ee9.zip
 ```
 

+ 2 - 2
src/client/remote_block.py

@@ -8,8 +8,8 @@ from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.dht_handler import _get_experts
 from hivemind.p2p import StubBase, P2P
 from hivemind.proto.runtime_pb2 import ExpertInfo
-from hivemind.dht import DHTExpiration, DHT
-from hivemind.utils import MPFuture
+from hivemind.dht import DHT
+from hivemind.utils import MPFuture, DHTExpiration
 from src.server.handler import TransformerConnectionHandler
 
 

+ 5 - 6
src/server/backend.py

@@ -2,14 +2,13 @@
 from typing import Tuple, Sequence
 
 import torch
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 
-from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 
 
-class TransformerBlockBackend(ExpertBackend):
+class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
@@ -21,12 +20,12 @@ class TransformerBlockBackend(ExpertBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
     def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
-        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
-            return inputs[0] * 2
+        with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            cache[...] += 1
+            return inputs[0] + cache
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool

+ 7 - 5
src/server/handler.py

@@ -1,19 +1,21 @@
 from typing import AsyncIterator, Dict
 
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
+from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, ModuleBackend
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
 
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import TransformerBackend
 
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+        super().__init__(dht, module_backends)
+        for module_backend in module_backends.values():
+            assert isinstance(module_backend, TransformerBackend)
 
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -21,7 +23,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         request = await anext(requests)
         backend = self.experts[request.uid]
-        assert isinstance(backend, TransformerBlockBackend)
+        assert isinstance(backend, TransformerBackend)
 
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         async with backend.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):

+ 3 - 3
src/server/server.py

@@ -14,7 +14,7 @@ import multiprocessing as mp
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import TransformerBackend
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -27,7 +27,7 @@ class Server(threading.Thread):
     def __init__(
         self,
         dht: DHT,
-        module_backends: Dict[str, TransformerBlockBackend],
+        module_backends: Dict[str, TransformerBackend],
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
@@ -118,7 +118,7 @@ class Server(threading.Thread):
             for param in block.parameters():
                 param.requires_grad = False
 
-            blocks[module_uid] = TransformerBlockBackend(
+            blocks[module_uid] = TransformerBackend(
                 module_uid,
                 block,
                 memory_cache=memory_cache,