justheuristic 3 lat temu
rodzic
commit
1cca611c9f
3 zmienionych plików z 12 dodań i 14 usunięć
  1. 1 1
      src/server/backend.py
  2. 5 6
      src/server/handler.py
  3. 6 7
      src/server/server.py

+ 1 - 1
src/server/backend.py

@@ -17,7 +17,7 @@ from src.server.cache import MemoryCache
 # - ensure that optimizer/scheduler is not created
 
 
-class TransformerBlockBackend(ExpertBackend):
+class BloomBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
     def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
         object().__init__()  # to bypass super.__init__

+ 5 - 6
src/server/handler.py

@@ -4,20 +4,19 @@ from hivemind import P2PContext, DHT
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 
+from src.bloom.block import BloomBlock
+
 
 class BloomConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
-    def __init__(self, dht: DHT, experts: Dict[str, BloomBackend]):
-        super().__init__()
-        self.dht, self.experts = dht, experts
-        self._p2p: Optional[P2P] = None
-
-        self.ready = MPFuture()
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
     async def rpc_forward_incremental(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        # note: you may use self.experts[uid].memory_cache!
         # encode expert_uid as @model_name[starting_layer:finishing_layer]
         # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
         # - receive and maintain a handle for attention cache here

+ 6 - 7
src/server/server.py

@@ -14,7 +14,7 @@ import multiprocessing as mp
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import BloomBlockBackend
 from src.server.handler import BloomConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -24,16 +24,14 @@ logger = get_logger(__file__)
 class Server(threading.Thread):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
     def __init__(
-            self, dht: DHT, module_backends: Dict[str, TransformerBlockBackend], *,
+            self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
             device: torch.device, num_connection_handlers: int = 8,
             update_period: float = 30, expiration: Optional[float] = None,
             start: bool, **kwargs
     ):
         threading.Thread.__init__(self)
         self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [
-            BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
-        ]
+        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
@@ -102,16 +100,17 @@ class Server(threading.Thread):
         num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
-
+        memory_cache = MemoryCache(device, cache_size_bytes)
         # initialize modules
         blocks = {}
         for i in range(num_blocks):
             module_uid = f"dummy_block.{i}"
             HARDCODCED_LENGTH = 2048
 
-            blocks[module_uid] = TransformerBlockBackend(
+            blocks[module_uid] = BloomBlockBackend(
                 module_uid,
                 BloomBlock(block_config, layer_number=i),
+                memory_cache=memory_cache,
                 args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
                 kwargs_schema={},
                 outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),