justheuristic 3 лет назад
Родитель
Сommit
7ce7cd7a97
4 измененных файлов с 117 добавлено и 36 удалено
  1. 2 2
      cli/run_server.py
  2. 1 3
      src/server/backend.py
  3. 9 2
      src/server/handler.py
  4. 105 29
      src/server/server.py

+ 2 - 2
cli/run_server.py

@@ -7,7 +7,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.server.server import BloomServer
+from src.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -63,7 +63,7 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
-    server = BloomServer.create(**args, start=True, compression=compression)
+    server = Server.create(**args, start=True, compression=compression)
 
     try:
         server.join()

+ 1 - 3
src/server/backend.py

@@ -16,10 +16,8 @@ from src.server.cache import MemoryCache
 # - ensure that TaskPool for inference is NOT batched
 # - ensure that optimizer/scheduler is not created
 
-HARDCODCED_LENGTH = 2048
 
-
-class BloomBlockBackend(ExpertBackend):
+class TransformerBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
     def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
         object().__init__()  # to bypass super.__init__

+ 9 - 2
src/server/handler.py

@@ -1,6 +1,6 @@
-from typing import AsyncIterator
+from typing import AsyncIterator, Dict
 
-from hivemind import P2PContext
+from hivemind import P2PContext, DHT
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 
@@ -8,6 +8,13 @@ from hivemind.proto import runtime_pb2
 class BloomConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
 
+    def __init__(self, dht: DHT, experts: Dict[str, BloomBackend]):
+        super().__init__()
+        self.dht, self.experts = dht, experts
+        self._p2p: Optional[P2P] = None
+
+        self.ready = MPFuture()
+
     async def rpc_forward_incremental(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:

+ 105 - 29
src/server/server.py

@@ -1,43 +1,75 @@
+from __future__ import annotations
 import threading
 from typing import Optional, Dict, Union, Sequence
 
 import torch
-from hivemind import Server, DHT
+from hivemind import DHT, BatchTensorDescriptor
 from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+import multiprocessing as mp
 
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import BloomBlockBackend
+from src.server.backend import TransformerBlockBackend
 from src.server.handler import BloomConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class BloomServer(Server):
+class Server(threading.Thread):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
     def __init__(
-            self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
-            device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
-            cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
+            self, dht: DHT, module_backends: Dict[str, TransformerBlockBackend], *,
+            device: torch.device, num_connection_handlers: int = 8,
+            update_period: float = 30, expiration: Optional[float] = None,
+            start: bool, **kwargs
     ):
         threading.Thread.__init__(self)
-        self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
-
         self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [
+            BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+        ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
-        self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
+        self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
         if start:
             self.run_in_background(await_ready=True)
 
+    def run(self):
+        """
+        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Serving {len(self.module_backends)} blocks:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.module_backends:
+            self.dht_handler_thread.start()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.result()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
     # noinspection PyMethodOverriding
     @classmethod
     def create(
@@ -69,40 +101,84 @@ class BloomServer(Server):
 
         num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        if isinstance(block_config, str):
-            block_config = DistributedBloomConfig
+        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
 
         # initialize modules
-        module_backends = {}
-        for i in range(len(module_backends)):
+        blocks = {}
+        for i in range(num_blocks):
             module_uid = f"dummy_block.{i}"
-            block = BloomBlock(block_config, layer_number=i)
-            #TODO run the actual model
-
-            module_backends[module_uid] = BloomBlockBackend(
-                name=expert_uid,
-                expert=block,
-                args_schema=args_schema,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
+            HARDCODCED_LENGTH = 2048
+
+            blocks[module_uid] = TransformerBlockBackend(
+                module_uid,
+                BloomBlock(block_config, layer_number=i),
+                args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
+                kwargs_schema={},
+                outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
 
-        if checkpoint_dir is not None:
-            load_experts(experts, checkpoint_dir)
-
         return cls(
             dht,
-            experts,
+            blocks,
             cache_size_bytes=cache_size_bytes,
             num_connection_handlers=num_handlers,
             device=device,
-            checkpoint_dir=checkpoint_dir,
             stats_report_interval=stats_report_interval,
             update_period=update_period,
             expiration=expiration,
             start=start,
         )
 
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts Server in a background thread. if await_ready, this method will wait until background server
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the server, process-safe.
+        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
+
+        for process in self.conn_handlers:
+            process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.module_backends:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        self.dht.shutdown()
+        self.dht.join()
+
+        logger.debug(f"Shutting down runtime")
+
+        self.runtime.shutdown()
+        logger.info("Server shutdown succesfully")
+