Sfoglia il codice sorgente

Server is not a thread anymore, so it catches KeyboardInterrupt

Aleksandr Borzunov 2 anni fa
parent
commit
52ea24730b
2 file modificati con 16 aggiunte e 10 eliminazioni
  1. +2 -3
      cli/run_server.py
  2. +14 -7
      src/server/server.py

+ 2 - 3
cli/run_server.py

@@ -124,10 +124,9 @@ def main():
     use_auth_token = args.pop("use_auth_token")
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
 
-    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
-
+    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:
     try:
-        server.join()
+        server.run()
     except KeyboardInterrupt:
     except KeyboardInterrupt:
         logger.info("Caught KeyboardInterrupt, shutting down")
         logger.info("Caught KeyboardInterrupt, shutting down")
     finally:
     finally:

+ 14 - 7
src/server/server.py

@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 logger = get_logger(__file__)
 
 
 
 
-class Server(threading.Thread):
+class Server:
     """
     """
     Runs ModuleContainer, periodically checks that the network is balanced,
     Runs ModuleContainer, periodically checks that the network is balanced,
     restarts the ModuleContainer with other layers if the imbalance is significant
     restarts the ModuleContainer with other layers if the imbalance is significant
@@ -68,13 +68,10 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         load_in_8bit: bool = False,
-        start: bool,
         **kwargs,
         **kwargs,
     ):
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
 
 
-        super().__init__()
-
         self.converted_model_name_or_path = converted_model_name_or_path
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
@@ -147,8 +144,6 @@ class Server(threading.Thread):
         self.mean_block_selection_delay = mean_block_selection_delay
         self.mean_block_selection_delay = mean_block_selection_delay
 
 
         self.stop = threading.Event()
         self.stop = threading.Event()
-        if start:
-            self.start()
 
 
     def run(self):
     def run(self):
         while True:
         while True:
@@ -312,7 +307,19 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                     max_batch_size=max_batch_size,
                 )
                 )
-        finally:
+        except:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+            declare_active_modules(
+                dht,
+                module_uids,
+                expiration_time=get_dht_time() + expiration,
+                state=ServerState.OFFLINE,
+                throughput=throughput,
+            )
+            logger.info(f"Announced that blocks {module_uids} are offline")
+            raise
+        else:
             joining_announcer.stop.set()
             joining_announcer.stop.set()
             joining_announcer.join()
             joining_announcer.join()