Prechádzať zdrojové kódy

Fix OOMs during server rebalancing (#150)

The cause of OOMs were the cyclic references `TransformerBackend <-> PrioritizedTaskPool` that could not have been garbage collected properly. Still, I've added explicit tensor removal just in case.
Alexander Borzunov 2 rokov pred
rodič
commit
e4dc938dfe
2 zmenil súbory, kde vykonal 33 pridanie a 7 odobranie
  1. 10 0
      src/petals/server/backend.py
  2. 23 7
      src/petals/server/server.py

+ 10 - 0
src/petals/server/backend.py

@@ -85,3 +85,13 @@ class TransformerBackend(ModuleBackend):
     def get_info(self) -> Dict[str, Any]:
         """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)
+
+    def shutdown(self):
+        # Break the cyclic references, otherwise TransformerBackend may be not garbage-collected
+        self.forward_pool = self.backward_pool = self.inference_pool = None
+
+        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
+        # but may help to avoid future issues when the module is not garbage-collected for some reasons
+        dummy = torch.tensor([])
+        for p in self.module.parameters():
+            p.data = dummy

+ 23 - 7
src/petals/server/server.py

@@ -235,8 +235,8 @@ class Server:
                     if self.stop.wait(timeout):
                         return
 
-                    if not self.module_container.handlers_alive:
-                        logger.warning("One of connection handlers crashed, restarting the server")
+                    if not self.module_container.is_healthy():
+                        logger.warning("One of subprocesses crashed, restarting the server")
                         break
 
                     if self._should_choose_other_blocks():
@@ -252,8 +252,19 @@ class Server:
         gc.collect()  # In particular, this closes unused file descriptors
 
         cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
-        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
+
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+
+            allocated_vram = torch.cuda.memory_allocated(self.device)
+            reserved_vram = torch.cuda.memory_reserved(self.device)
+            gib = 1024**3
+            logger.info(
+                f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
+                f"{reserved_vram / gib:.1f} GiB reserved memory"
+            )
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
@@ -470,9 +481,10 @@ class ModuleContainer(threading.Thread):
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
-    @property
-    def handlers_alive(self) -> bool:
-        return all(handler.is_alive() for handler in self.conn_handlers)
+    def is_healthy(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers) and all(
+            pool.is_alive() for pool in self.runtime.pools
+        )
 
     def shutdown(self):
         """
@@ -510,6 +522,10 @@ class ModuleContainer(threading.Thread):
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()
 
+        logger.debug("Shutting down backends")
+        for backend in self.module_backends.values():
+            backend.shutdown()
+
         logger.info("Module container shut down successfully")