
Announce JOINING periodically

Aleksandr Borzunov 2 years ago
parent
commit
06a8246ae9
1 changed file with 120 additions and 112 deletions
src/server/server.py
+ 120 - 112

@@ -231,6 +231,106 @@ class Server(threading.Thread):
 class ModuleContainer(threading.Thread):
     """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
 
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        joining_announcer = ModuleAnnouncerThread(
+            module_uids,
+            dht,
+            ServerState.JOINING,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        joining_announcer.start()
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        try:
+            blocks = {}
+            for module_uid, block_index in zip(module_uids, block_indices):
+                block = load_pretrained_block(
+                    converted_model_name_or_path,
+                    block_index,
+                    block_config,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                )
+
+                if load_in_8bit:
+                    dtype = block.input_layernorm.weight.dtype
+                    block = replace_8bit_linear(block)
+
+                block = block.to(device)
+                for param in block.parameters():
+                    param.requires_grad = False
+
+                blocks[module_uid] = TransformerBackend(
+                    module_uid,
+                    block,
+                    memory_cache=memory_cache,
+                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                    args_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    kwargs_schema={},
+                    outputs_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+        finally:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
+            start=start,
+        )
+
     def __init__(
         self,
         dht: DHT,
@@ -253,9 +353,10 @@ class ModuleContainer(threading.Thread):
             for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
+        self.online_announcer = ModuleAnnouncerThread(
+            list(self.module_backends.keys()),
             dht,
+            ServerState.ONLINE,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -279,8 +380,7 @@ class ModuleContainer(threading.Thread):
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
 
-        if self.module_backends:
-            self.dht_handler_thread.start()
+        self.online_announcer.start()
 
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
@@ -290,99 +390,6 @@ class ModuleContainer(threading.Thread):
 
         self.runtime.run()
 
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
-        *,
-        dht: DHT,
-        prefix: str,
-        converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        memory_cache: MemoryCache,
-        throughput: float,
-        block_indices: List[int],
-        num_handlers: Optional[int],
-        min_batch_size: int,
-        max_batch_size: int,
-        inference_max_length: int,
-        torch_dtype: torch.dtype,
-        cache_dir: Optional[str],
-        device: Union[str, torch.device],
-        compression: CompressionType,
-        stats_report_interval: Optional[int],
-        update_period: float,
-        expiration: Optional[float],
-        prefetch_batches: int,
-        sender_threads: int,
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
-        start: bool,
-    ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        declare_active_modules(
-            dht,
-            module_uids,
-            expiration_time=get_dht_time() + expiration,
-            state=ServerState.JOINING,
-            throughput=throughput,
-        )
-        logger.info(f"Announced that blocks {block_indices} are joining")
-
-        blocks = {}
-        for module_uid, block_index in zip(module_uids, block_indices):
-            block = load_pretrained_block(
-                converted_model_name_or_path,
-                block_index,
-                block_config,
-                torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-            )
-
-            if load_in_8bit:
-                dtype = block.input_layernorm.weight.dtype
-                block = replace_8bit_linear(block)
-
-            block = block.to(device)
-            for param in block.parameters():
-                param.requires_grad = False
-
-            blocks[module_uid] = TransformerBackend(
-                module_uid,
-                block,
-                memory_cache=memory_cache,
-                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
-                args_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                kwargs_schema={},
-                outputs_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        return cls(
-            dht,
-            blocks,
-            throughput=throughput,
-            num_connection_handlers=num_handlers,
-            inference_max_length=inference_max_length,
-            device=device,
-            stats_report_interval=stats_report_interval,
-            update_period=update_period,
-            expiration=expiration,
-            prefetch_batches=prefetch_batches,
-            sender_threads=sender_threads,
-            start=start,
-        )
-
     def run_in_background(self, await_ready=True, timeout=None):
         """
         Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
@@ -411,18 +418,17 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
+        self.online_announcer.stop.set()
+        self.online_announcer.join()
 
-            declare_active_modules(
-                self.dht,
-                self.module_backends.keys(),
-                expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.OFFLINE,
-                throughput=self.throughput,
-            )
-            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+        declare_active_modules(
+            self.dht,
+            self.module_backends.keys(),
+            expiration_time=get_dht_time() + self.expiration,
+            state=ServerState.OFFLINE,
+            throughput=self.throughput,
+        )
+        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
         self.ready.clear()
 
@@ -450,8 +456,9 @@ class ModuleAnnouncerThread(threading.Thread):
 
     def __init__(
         self,
-        module_backends: Dict[str, TransformerBackend],
+        module_uids: List[str],
         dht: DHT,
+        state: ServerState,
         *,
         throughput: float,
         update_period: float = 30,
@@ -459,8 +466,9 @@ class ModuleAnnouncerThread(threading.Thread):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.module_backends = module_backends
+        self.module_uids = module_uids
         self.dht = dht
+        self.state = state
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -470,9 +478,9 @@ class ModuleAnnouncerThread(threading.Thread):
         while True:
             declare_active_modules(
                 self.dht,
-                self.module_backends.keys(),
+                self.module_uids,
                 expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.ONLINE,
+                state=self.state,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):
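The net effect of this commit is that announcements become state-aware: create() keeps re-publishing its blocks as JOINING from a background ModuleAnnouncerThread while they load, the running container announces ONLINE periodically, and shutdown announces OFFLINE once. Below is a minimal, self-contained sketch of that lifecycle for illustration only; announce(), AnnouncerThread, create_container(), and the uid strings are hypothetical stand-ins rather than the project's real API, which uses declare_active_modules() and a hivemind DHT.

import threading
import time
from enum import Enum


class State(Enum):
    JOINING = "joining"
    ONLINE = "online"
    OFFLINE = "offline"


def announce(uids, state, expiration=30.0):
    # Hypothetical stand-in for declare_active_modules(): in the real code this
    # writes (uid -> state, throughput) records to the DHT with an expiration time.
    print(f"announce {uids} as {state.value} (expires in {expiration:.0f}s)")


class AnnouncerThread(threading.Thread):
    """Re-announces a fixed state every update_period seconds until stopped."""

    def __init__(self, uids, state, update_period=30.0, **kwargs):
        super().__init__(**kwargs)
        self.uids, self.state, self.update_period = uids, state, update_period
        self.stop = threading.Event()

    def run(self):
        while True:
            announce(self.uids, self.state)
            if self.stop.wait(self.update_period):
                break


def create_container(uids):
    # Announce JOINING on a background thread so the record stays fresh
    # even if loading the blocks takes longer than the expiration time.
    joining = AnnouncerThread(uids, State.JOINING, update_period=1.0, daemon=True)
    joining.start()
    try:
        time.sleep(3)  # stands in for loading one block per uid
    finally:
        joining.stop.set()
        joining.join()
    # Once everything is loaded, switch to an ONLINE announcer that runs
    # for the lifetime of the container.
    online = AnnouncerThread(uids, State.ONLINE, update_period=1.0, daemon=True)
    online.start()
    return online


if __name__ == "__main__":
    online = create_container(["bloom.3", "bloom.4"])
    time.sleep(2)
    online.stop.set()
    online.join()
    announce(["bloom.3", "bloom.4"], State.OFFLINE)  # one-shot announcement on shutdown

Running the sketch prints a short stream of JOINING announcements, then ONLINE ones, then a single OFFLINE announcement, mirroring the order of calls in the diff above.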