瀏覽代碼

Fix ModuleContainer.shutdown() and its usages

Aleksandr Borzunov 2 年之前
父節點
當前提交
f3ea120c81
共有 1 個文件被更改,包括 8 次插入和 13 次刪除
  1. 8 13
      src/server/server.py

+ 8 - 13
src/server/server.py

@@ -60,7 +60,7 @@ class Server(threading.Thread):
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         mean_block_selection_delay: float = 0.5,
-        mean_balance_check_period: float = 60,  # TODO:
+        mean_balance_check_period: float = 300,  # TODO:
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -175,6 +175,7 @@ class Server(threading.Thread):
 
                 while True:
                     timeout = random.random() * 2 * self.mean_balance_check_period
+                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
                     if self.stop.wait(timeout):
                         return
                     if self._should_choose_other_blocks():
@@ -251,7 +252,7 @@ class ModuleContainer(threading.Thread):
     def run(self):
         """
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs hivemind.Runtime (self.runtime) to process incoming requests.
+        runs Runtime (self.runtime) to process incoming requests.
         """
         logger.info(f"Serving {len(self.module_backends)} blocks:")
         for expert_name, backend in self.module_backends.items():
@@ -267,15 +268,10 @@ class ModuleContainer(threading.Thread):
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.result()
+        for handler in self.conn_handlers:
+            handler.run_in_background()
 
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
+        self.runtime.run()
 
     # noinspection PyMethodOverriding
     @classmethod
@@ -413,9 +409,8 @@ class ModuleContainer(threading.Thread):
 
         self.ready.clear()
 
-        for process in self.conn_handlers:
-            process.terminate()
-            process.join()
+        for handler in self.conn_handlers:
+            handler.shutdown()
         logger.debug("Connection handlers terminated")
 
         if self.checkpoint_saver is not None: