ソースを参照

Fix "Too many open files" during rebalancing

Aleksandr Borzunov 2 年 前
コミット
75ed6ac49c
3 ファイル変更31 行追加5 行削除
  1. 3 0
      src/server/block_selection.py
  2. 17 0
      src/server/server.py
  3. 11 5
      src/server/task_pool.py

+ 3 - 0
src/server/block_selection.py

@@ -64,6 +64,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule
 def should_choose_other_blocks(
     local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
 ) -> bool:
+    if min_balance_quality > 1.0:
+        return True  # Forces rebalancing on each check (may be used for debugging purposes)
+
     spans, throughputs = _compute_spans(module_infos)
     initial_throughput = throughputs.min()
 

+ 17 - 0
src/server/server.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 import multiprocessing as mp
 import random
 import threading
@@ -7,6 +8,7 @@ import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
+import psutil
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -187,6 +189,16 @@ class Server(threading.Thread):
             finally:
                 self.module_container.shutdown()
 
+            self._clean_memory_and_fds()
+
+    def _clean_memory_and_fds(self):
+        del self.module_container
+        gc.collect()  # In particular, this closes unused file descriptors
+
+        cur_proc = psutil.Process()
+        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
+        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
             return self.strict_block_indices
@@ -418,6 +430,11 @@ class ModuleContainer(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
+        logger.debug(f"Shutting down pools")
+        for pool in self.runtime.pools:
+            if pool.is_alive():
+                pool.shutdown()
+
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()
 

+ 11 - 5
src/server/task_pool.py

@@ -70,6 +70,8 @@ class PrioritizedTaskPool(TaskPoolBase):
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+
+        self._stop = mp.Event()
         if start:
             self.start()
 
@@ -89,10 +91,14 @@ class PrioritizedTaskPool(TaskPoolBase):
         self._prioritizer_thread.start()
         super().start()
 
-    def shutdown(self, timeout: Optional[float] = None):
-        self.submitted_tasks.put(None)
-        self.terminate()
-        self._prioritizer_thread.join(timeout)
+    def shutdown(self, timeout: float = 3):
+        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
+        self._stop.set()
+
+        self.join(timeout)
+        if self.is_alive():
+            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
+            self.terminate()
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -154,7 +160,7 @@ class PrioritizedTaskPool(TaskPoolBase):
             task.future.set_exception(exception)
 
     def run(self, *args, **kwargs):
-        mp.Event().wait()
+        self._stop.wait()
 
     @property
     def empty(self):