Ver Fonte

Announce JOINING periodically

Aleksandr Borzunov há 2 anos atrás
pai
commit
06a8246ae9
1 ficheiros alterados com 120 adições e 112 exclusões
  1. 120 112
      src/server/server.py

+ 120 - 112
src/server/server.py

@@ -231,6 +231,106 @@ class Server(threading.Thread):
 class ModuleContainer(threading.Thread):
     """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
 
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        joining_announcer = ModuleAnnouncerThread(
+            module_uids,
+            dht,
+            ServerState.JOINING,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        joining_announcer.start()
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        try:
+            blocks = {}
+            for module_uid, block_index in zip(module_uids, block_indices):
+                block = load_pretrained_block(
+                    converted_model_name_or_path,
+                    block_index,
+                    block_config,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                )
+
+                if load_in_8bit:
+                    dtype = block.input_layernorm.weight.dtype
+                    block = replace_8bit_linear(block)
+
+                block = block.to(device)
+                for param in block.parameters():
+                    param.requires_grad = False
+
+                blocks[module_uid] = TransformerBackend(
+                    module_uid,
+                    block,
+                    memory_cache=memory_cache,
+                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                    args_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    kwargs_schema={},
+                    outputs_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+        finally:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
+            start=start,
+        )
+
     def __init__(
         self,
         dht: DHT,
@@ -253,9 +353,10 @@ class ModuleContainer(threading.Thread):
             for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
+        self.online_announcer = ModuleAnnouncerThread(
+            list(self.module_backends.keys()),
             dht,
+            ServerState.ONLINE,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -279,8 +380,7 @@ class ModuleContainer(threading.Thread):
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
 
-        if self.module_backends:
-            self.dht_handler_thread.start()
+        self.online_announcer.start()
 
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
@@ -290,99 +390,6 @@ class ModuleContainer(threading.Thread):
 
         self.runtime.run()
 
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
-        *,
-        dht: DHT,
-        prefix: str,
-        converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        memory_cache: MemoryCache,
-        throughput: float,
-        block_indices: List[int],
-        num_handlers: Optional[int],
-        min_batch_size: int,
-        max_batch_size: int,
-        inference_max_length: int,
-        torch_dtype: torch.dtype,
-        cache_dir: Optional[str],
-        device: Union[str, torch.device],
-        compression: CompressionType,
-        stats_report_interval: Optional[int],
-        update_period: float,
-        expiration: Optional[float],
-        prefetch_batches: int,
-        sender_threads: int,
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
-        start: bool,
-    ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        declare_active_modules(
-            dht,
-            module_uids,
-            expiration_time=get_dht_time() + expiration,
-            state=ServerState.JOINING,
-            throughput=throughput,
-        )
-        logger.info(f"Announced that blocks {block_indices} are joining")
-
-        blocks = {}
-        for module_uid, block_index in zip(module_uids, block_indices):
-            block = load_pretrained_block(
-                converted_model_name_or_path,
-                block_index,
-                block_config,
-                torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-            )
-
-            if load_in_8bit:
-                dtype = block.input_layernorm.weight.dtype
-                block = replace_8bit_linear(block)
-
-            block = block.to(device)
-            for param in block.parameters():
-                param.requires_grad = False
-
-            blocks[module_uid] = TransformerBackend(
-                module_uid,
-                block,
-                memory_cache=memory_cache,
-                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
-                args_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                kwargs_schema={},
-                outputs_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        return cls(
-            dht,
-            blocks,
-            throughput=throughput,
-            num_connection_handlers=num_handlers,
-            inference_max_length=inference_max_length,
-            device=device,
-            stats_report_interval=stats_report_interval,
-            update_period=update_period,
-            expiration=expiration,
-            prefetch_batches=prefetch_batches,
-            sender_threads=sender_threads,
-            start=start,
-        )
-
     def run_in_background(self, await_ready=True, timeout=None):
         """
         Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
@@ -411,18 +418,17 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
+        self.online_announcer.stop.set()
+        self.online_announcer.join()
 
-            declare_active_modules(
-                self.dht,
-                self.module_backends.keys(),
-                expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.OFFLINE,
-                throughput=self.throughput,
-            )
-            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+        declare_active_modules(
+            self.dht,
+            self.module_backends.keys(),
+            expiration_time=get_dht_time() + self.expiration,
+            state=ServerState.OFFLINE,
+            throughput=self.throughput,
+        )
+        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
         self.ready.clear()
 
@@ -450,8 +456,9 @@ class ModuleAnnouncerThread(threading.Thread):
 
     def __init__(
         self,
-        module_backends: Dict[str, TransformerBackend],
+        module_uids: List[str],
         dht: DHT,
+        state: ServerState,
         *,
         throughput: float,
         update_period: float = 30,
@@ -459,8 +466,9 @@ class ModuleAnnouncerThread(threading.Thread):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.module_backends = module_backends
+        self.module_uids = module_uids
         self.dht = dht
+        self.state = state
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -470,9 +478,9 @@ class ModuleAnnouncerThread(threading.Thread):
         while True:
             declare_active_modules(
                 self.dht,
-                self.module_backends.keys(),
+                self.module_uids,
                 expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.ONLINE,
+                state=self.state,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):