Sfoglia il codice sorgente

Server is not a thread anymore, so it catches KeyboardInterrupt

Aleksandr Borzunov 2 anni fa
parent
commit
52ea24730b
2 ha cambiato i file con 16 aggiunte e 10 eliminazioni
  1. 2 3
      cli/run_server.py
  2. 14 7
      src/server/server.py

+ 2 - 3
cli/run_server.py

@@ -124,10 +124,9 @@ def main():
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
-
+    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:
-        server.join()
+        server.run()
     except KeyboardInterrupt:
         logger.info("Caught KeyboardInterrupt, shutting down")
     finally:

+ 14 - 7
src/server/server.py

@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class Server(threading.Thread):
+class Server:
     """
     Runs ModuleContainer, periodically checks that the network is balanced,
     restarts the ModuleContainer with other layers if the imbalance is significant
@@ -68,13 +68,10 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
-        start: bool,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
 
-        super().__init__()
-
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
@@ -147,8 +144,6 @@ class Server(threading.Thread):
         self.mean_block_selection_delay = mean_block_selection_delay
 
         self.stop = threading.Event()
-        if start:
-            self.start()
 
     def run(self):
         while True:
@@ -312,7 +307,19 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
-        finally:
+        except:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+            declare_active_modules(
+                dht,
+                module_uids,
+                expiration_time=get_dht_time() + expiration,
+                state=ServerState.OFFLINE,
+                throughput=throughput,
+            )
+            logger.info(f"Announced that blocks {module_uids} are offline")
+            raise
+        else:
             joining_announcer.stop.set()
             joining_announcer.join()