5
0
Эх сурвалжийг харах

Extract ModuleContainer class from Server

Aleksandr Borzunov 3 жил өмнө
parent
commit
c8b85555f4
2 өөрчлөгдсөн 193 нэмэгдсэн , 95 устгасан
  1. 2 2
      cli/run_server.py
  2. 191 93
      src/server/server.py

+ 2 - 2
cli/run_server.py

@@ -43,7 +43,7 @@ def main():
                         help='Use this many threads to pass results/exceptions from Runtime to Pools')
     parser.add_argument('--inference_max_length', type=int, default=16384,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
-    parser.add_argument('--cache_dir', type=str, default=None, 
+    parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all blocks will use this device in torch notation; default: cuda if available else cpu')
@@ -104,7 +104,7 @@ def main():
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
+    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
 
     try:
         server.join()

+ 191 - 93
src/server/server.py

@@ -4,7 +4,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, Optional, Sequence, Union
+from typing import Dict, Optional, List, Sequence, Union
 
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -29,76 +29,14 @@ logger = get_logger(__file__)
 
 
 class Server(threading.Thread):
-    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
+    """
+    Runs Server, periodically checks that the network is balanced,
+    restarts the Server with other layers if the imbalance is significant
+    """
 
     def __init__(
         self,
-        dht: DHT,
-        module_backends: Dict[str, TransformerBackend],
-        *,
-        inference_max_length: int,
-        num_connection_handlers: int = 8,
-        throughput: float,
-        update_period: float = 30,
-        expiration: Optional[float] = None,
-        start: bool,
-        **kwargs,
-    ):
-        threading.Thread.__init__(self)
-        self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
-        self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
-            for _ in range(num_connection_handlers)
-        ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
-            dht,
-            throughput=throughput,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for block_name, backend in self.module_backends.items():
-            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            parameter_msg = f"{num_parameters} trainable parameters" if num_parameters else "frozen"
-            logger.info(f"{block_name}: {backend.module.__class__.__name__}, {parameter_msg}")
-
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
 
-        if self.module_backends:
-            self.dht_handler_thread.start()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.result()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
         throughput: Union[float, str],
@@ -127,10 +65,26 @@ class Server(threading.Thread):
         *,
         start: bool,
         **kwargs,
-    ) -> Server:
+    ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
+
+        super().__init__()
+
+        self.converted_model_name_or_path = converted_model_name_or_path
+        self.num_handlers = num_handlers
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+        self.cache_dir = cache_dir
+        self.attn_cache_size = attn_cache_size
+        self.compression = compression
+        self.stats_report_interval, self.update_period = stats_report_interval, update_period
+        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
+        self.use_auth_token = use_auth_token
+        self.load_in_8bit = load_in_8bit
+
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
+
         if prefix is None:
             prefix = converted_model_name_or_path
             assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
@@ -138,27 +92,37 @@ class Server(threading.Thread):
                 f"Please specify --prefix manually when starting a server"
             )
             logger.info(f"Automatic dht prefix: {prefix}")
+        self.prefix = prefix
+
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        self.expiration = expiration
 
-        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
-        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, attn_cache_size)
+        self.device = device
+
+        self.memory_cache = MemoryCache(device, attn_cache_size)
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
+        self.throughput = throughput
 
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
+        self.block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
         )
 
         if block_indices is not None:
@@ -175,10 +139,148 @@ class Server(threading.Thread):
             time.sleep(random.random() * max_block_selection_delay)
 
             assert num_blocks is not None
-            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
-            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
+            uids = [f"{prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+            module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
             block_indices = choose_best_blocks(num_blocks, module_infos)
+        self.block_indices = block_indices
+
+        self.stop = threading.Event()
+        if start:
+            self.start()
+
+    def run(self):
+        self.module_container = ModuleContainer.create(
+            dht=self.dht,
+            prefix=self.prefix,
+            converted_model_name_or_path=self.converted_model_name_or_path,
+            block_config=self.block_config,
+            memory_cache=self.memory_cache,
+            throughput=self.throughput,
+            block_indices=self.block_indices,
+            num_handlers=self.num_handlers,
+            min_batch_size=self.min_batch_size,
+            max_batch_size=self.max_batch_size,
+            inference_max_length=self.inference_max_length,
+            torch_dtype=self.torch_dtype,
+            cache_dir=self.cache_dir,
+            device=self.device,
+            compression=self.compression,
+            stats_report_interval=self.stats_report_interval,
+            update_period=self.update_period,
+            expiration=self.expiration,
+            prefetch_batches=self.prefetch_batches,
+            sender_threads=self.sender_threads,
+            use_auth_token=self.use_auth_token,
+            load_in_8bit=self.load_in_8bit,
+            start=True,
+        )
+        try:
+            self.stop.wait()
+        finally:
+            self.module_container.shutdown()
+
+    def shutdown(self):
+        self.stop.set()
+
+        self.dht.shutdown()
+        self.dht.join()
 
+
+class ModuleContainer(threading.Thread):
+    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
+
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        *,
+        device: torch.device,
+        num_connection_handlers: int,
+        throughput: float,
+        update_period: float,
+        expiration: Optional[float] = None,
+        start: bool,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.conn_handlers = [
+            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+        ]
+        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.dht_handler_thread = ModuleAnnouncerThread(
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    def run(self):
+        """
+        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Serving {len(self.module_backends)} blocks:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.module_backends:
+            self.dht_handler_thread.start()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.result()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
         declare_active_modules(
             dht,
@@ -245,33 +347,36 @@ class Server(threading.Thread):
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
-        Starts Server in a background thread. if await_ready, this method will wait until background server
+        Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError("ModuleContainer didn't notify .ready in {timeout} seconds")
 
     @property
     def ready(self) -> mp.synchronize.Event:
         """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+        An event (multiprocessing.Event) that is set when the container is ready to process requests.
 
         Example
         =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        >>> container.start()
+        >>> container.ready.wait(timeout=10)
+        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
     def shutdown(self):
         """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
+        Gracefully terminate the container, process-safe.
+        Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         if self.module_backends:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
             declare_active_modules(
                 self.dht,
                 self.module_backends.keys(),
@@ -288,25 +393,18 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")
 
-        if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
-        self.dht.shutdown()
-        self.dht.join()
-
         logger.debug(f"Shutting down runtime")
-
         self.runtime.shutdown()
-        logger.info("Server shut down succesfully")
+
+        logger.info("Module container shut down succesfully")
 
 
 class ModuleAnnouncerThread(threading.Thread):
-    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
+    """Periodically announces that this container hosts the specified modules, visible to all DHT peers"""
 
     def __init__(
         self,