Make ServerState announcements work better (#93)

- Before this PR, `ServerState.JOINING` was announced only once. For the full-size BLOOM model, this announcement expires quickly, since loading blocks takes several minutes. This PR fixes that: `ServerState.JOINING` is now announced periodically in a background thread until the blocks are loaded (see the first sketch below).

- This PR also makes the `Server` class a plain class rather than a thread, so it runs in the main thread and can catch `KeyboardInterrupt`. This is important: if blocks are still being downloaded, we need to stop the download and send the `ServerState.OFFLINE` message. Note that `ModuleContainer` is still a thread (see the second sketch below).

- (minor) For readability, I moved the `ModuleContainer.create()` definition so that it is now defined before `ModuleContainer.__init__()`, since `.create()` is invoked first.
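
For reference, the periodic announcement boils down to the following pattern. This is a condensed sketch of the `ModuleAnnouncerThread` from the diff below, not new behavior; `declare_active_modules`, `get_dht_time`, `ServerState`, and `DHT` come from the existing codebase:

```python
import threading


class ModuleAnnouncerThread(threading.Thread):
    """Periodically re-announces a fixed set of module UIDs with a given state."""

    def __init__(self, module_uids, dht, state, *, throughput, update_period=30, expiration=None, **kwargs):
        super().__init__(**kwargs)
        self.module_uids, self.dht, self.state = module_uids, dht, state
        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
        self.stop = threading.Event()  # set this event to terminate the loop

    def run(self):
        while True:
            # Re-declare the modules so the DHT record never expires while we are alive
            declare_active_modules(
                self.dht,
                self.module_uids,
                expiration_time=get_dht_time() + self.expiration,
                state=self.state,
                throughput=self.throughput,
            )
            # Sleep for update_period seconds, waking immediately if stop is set
            if self.stop.wait(self.update_period):
                break
```

`ModuleContainer.create()` starts such a thread with `ServerState.JOINING` before loading blocks and stops it once loading finishes, whether it succeeds or fails.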
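
Similarly, the new control flow in `cli/run_server.py` has this shape (a sketch; the `finally` body is elided in the diff below):

```python
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
try:
    # run() blocks in the main thread, so KeyboardInterrupt interrupts block
    # loading itself, letting us abort the download and announce OFFLINE
    server.run()
except KeyboardInterrupt:
    logger.info("Caught KeyboardInterrupt, shutting down")
finally:
    ...  # cleanup (elided in the diff below)
```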
Alexander Borzunov, 2 years ago
commit 8a73b41a42
2 files changed with 135 additions and 121 deletions
  1. cli/run_server.py (+2 -3)
  2. src/server/server.py (+133 -118)

cli/run_server.py (+2 -3)

@@ -124,10 +124,9 @@ def main():
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
-
+    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:
-        server.join()
+        server.run()
     except KeyboardInterrupt:
         logger.info("Caught KeyboardInterrupt, shutting down")
     finally:

src/server/server.py (+133 -118)

@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class Server(threading.Thread):
+class Server:
     """
     Runs ModuleContainer, periodically checks that the network is balanced,
     restarts the ModuleContainer with other layers if the imbalance is significant
@@ -68,13 +68,10 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
-        start: bool,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
 
-        super().__init__()
-
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
@@ -147,8 +144,6 @@ class Server(threading.Thread):
         self.mean_block_selection_delay = mean_block_selection_delay
 
         self.stop = threading.Event()
-        if start:
-            self.start()
 
     def run(self):
         while True:
@@ -231,6 +226,118 @@
 class ModuleContainer(threading.Thread):
     """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
 
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        joining_announcer = ModuleAnnouncerThread(
+            module_uids,
+            dht,
+            ServerState.JOINING,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        joining_announcer.start()
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        try:
+            blocks = {}
+            for module_uid, block_index in zip(module_uids, block_indices):
+                block = load_pretrained_block(
+                    converted_model_name_or_path,
+                    block_index,
+                    block_config,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                )
+
+                if load_in_8bit:
+                    dtype = block.input_layernorm.weight.dtype
+                    block = replace_8bit_linear(block)
+
+                block = block.to(device)
+                for param in block.parameters():
+                    param.requires_grad = False
+
+                blocks[module_uid] = TransformerBackend(
+                    module_uid,
+                    block,
+                    memory_cache=memory_cache,
+                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                    args_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    kwargs_schema={},
+                    outputs_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+        except:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+            declare_active_modules(
+                dht,
+                module_uids,
+                expiration_time=get_dht_time() + expiration,
+                state=ServerState.OFFLINE,
+                throughput=throughput,
+            )
+            logger.info(f"Announced that blocks {module_uids} are offline")
+            raise
+        else:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
+            start=start,
+        )
+
     def __init__(
         self,
         dht: DHT,
@@ -253,9 +360,10 @@ class ModuleContainer(threading.Thread):
             for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
+        self.online_announcer = ModuleAnnouncerThread(
+            list(self.module_backends.keys()),
             dht,
+            ServerState.ONLINE,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -279,8 +387,7 @@ class ModuleContainer(threading.Thread):
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
 
-        if self.module_backends:
-            self.dht_handler_thread.start()
+        self.online_announcer.start()
 
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
@@ -290,99 +397,6 @@ class ModuleContainer(threading.Thread):
 
         self.runtime.run()
 
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
-        *,
-        dht: DHT,
-        prefix: str,
-        converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        memory_cache: MemoryCache,
-        throughput: float,
-        block_indices: List[int],
-        num_handlers: Optional[int],
-        min_batch_size: int,
-        max_batch_size: int,
-        inference_max_length: int,
-        torch_dtype: torch.dtype,
-        cache_dir: Optional[str],
-        device: Union[str, torch.device],
-        compression: CompressionType,
-        stats_report_interval: Optional[int],
-        update_period: float,
-        expiration: Optional[float],
-        prefetch_batches: int,
-        sender_threads: int,
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
-        start: bool,
-    ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        declare_active_modules(
-            dht,
-            module_uids,
-            expiration_time=get_dht_time() + expiration,
-            state=ServerState.JOINING,
-            throughput=throughput,
-        )
-        logger.info(f"Announced that blocks {block_indices} are joining")
-
-        blocks = {}
-        for module_uid, block_index in zip(module_uids, block_indices):
-            block = load_pretrained_block(
-                converted_model_name_or_path,
-                block_index,
-                block_config,
-                torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-            )
-
-            if load_in_8bit:
-                dtype = block.input_layernorm.weight.dtype
-                block = replace_8bit_linear(block)
-
-            block = block.to(device)
-            for param in block.parameters():
-                param.requires_grad = False
-
-            blocks[module_uid] = TransformerBackend(
-                module_uid,
-                block,
-                memory_cache=memory_cache,
-                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
-                args_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                kwargs_schema={},
-                outputs_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        return cls(
-            dht,
-            blocks,
-            throughput=throughput,
-            num_connection_handlers=num_handlers,
-            inference_max_length=inference_max_length,
-            device=device,
-            stats_report_interval=stats_report_interval,
-            update_period=update_period,
-            expiration=expiration,
-            prefetch_batches=prefetch_batches,
-            sender_threads=sender_threads,
-            start=start,
-        )
-
     def run_in_background(self, await_ready=True, timeout=None):
         """
         Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
@@ -411,18 +425,17 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
+        self.online_announcer.stop.set()
+        self.online_announcer.join()
 
 
-            declare_active_modules(
-                self.dht,
-                self.module_backends.keys(),
-                expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.OFFLINE,
-                throughput=self.throughput,
-            )
-            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+        declare_active_modules(
+            self.dht,
+            self.module_backends.keys(),
+            expiration_time=get_dht_time() + self.expiration,
+            state=ServerState.OFFLINE,
+            throughput=self.throughput,
+        )
+        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
         self.ready.clear()
 
@@ -450,8 +463,9 @@ class ModuleAnnouncerThread(threading.Thread):
 
     def __init__(
         self,
-        module_backends: Dict[str, TransformerBackend],
+        module_uids: List[str],
         dht: DHT,
+        state: ServerState,
         *,
         throughput: float,
         update_period: float = 30,
@@ -459,8 +473,9 @@ class ModuleAnnouncerThread(threading.Thread):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.module_backends = module_backends
+        self.module_uids = module_uids
         self.dht = dht
+        self.state = state
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -470,9 +485,9 @@ class ModuleAnnouncerThread(threading.Thread):
         while True:
             declare_active_modules(
                 self.dht,
-                self.module_backends.keys(),
+                self.module_uids,
                 expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.ONLINE,
+                state=self.state,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):