Fix OOMs during server rebalancing (#150)

The OOMs were caused by the cyclic references `TransformerBackend <-> PrioritizedTaskPool`, which could not be garbage-collected properly. Still, I've added explicit tensor removal just in case.
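
To illustrate the issue (a minimal standalone sketch with hypothetical `Backend`/`Pool` classes showing one way such a cycle can arise, not the actual Petals code): a backend that stores a pool, where the pool in turn stores a bound method of the backend, forms a reference cycle. Cyclic objects are only reclaimed when the cycle collector happens to run, so the backend's tensors can outlive the rebalancing step. Setting the pool attribute to None breaks the cycle, letting plain reference counting free everything promptly.

    import gc
    import torch

    class Pool:
        def __init__(self, process_fn):
            self.process_fn = process_fn  # bound method keeps a reference back to the backend

    class Backend:
        def __init__(self):
            self.weight = torch.zeros(1024, 1024)  # stand-in for large module parameters
            self.pool = Pool(self.forward)         # Backend -> Pool -> Backend cycle

        def forward(self, x):
            return x @ self.weight

        def shutdown(self):
            self.pool = None  # break the cycle, as the new TransformerBackend.shutdown() below does

    backend = Backend()
    backend.shutdown()
    del backend   # freed right away by reference counting
    gc.collect()  # without shutdown(), the cycle would only be reclaimed here, when the collector runs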
Alexander Borzunov, 2 years ago
commit e4dc938dfe
2 changed files with 33 additions and 7 deletions
  1. src/petals/server/backend.py (+10, −0)
  2. src/petals/server/server.py (+23, −7)

+ 10 - 0
src/petals/server/backend.py

@@ -85,3 +85,13 @@ class TransformerBackend(ModuleBackend):
     def get_info(self) -> Dict[str, Any]:
         """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)
+
+    def shutdown(self):
+        # Break the cyclic references, otherwise TransformerBackend may not be garbage-collected
+        self.forward_pool = self.backward_pool = self.inference_pool = None
+
+        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
+        # but may help to avoid future issues when the module is not garbage-collected for some reason
+        dummy = torch.tensor([])
+        for p in self.module.parameters():
+            p.data = dummy
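
A side note on the explicit tensor removal above (an illustrative snippet assuming a CUDA device is available, not part of the commit): reassigning `p.data` to an empty CPU tensor drops the last reference to each parameter's CUDA storage, so the allocator can free it even while the module object itself is still alive.

    import torch

    if torch.cuda.is_available():
        module = torch.nn.Linear(4096, 4096).cuda()
        print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")  # ~64 MiB for the weight

        dummy = torch.tensor([])
        for p in module.parameters():
            p.data = dummy  # the old CUDA storage loses its last reference and is freed

        print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")  # ~0 MiB, even without empty_cache()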

+ 23 - 7
src/petals/server/server.py

@@ -235,8 +235,8 @@ class Server:
                     if self.stop.wait(timeout):
                         return
 
-                    if not self.module_container.handlers_alive:
-                        logger.warning("One of connection handlers crashed, restarting the server")
+                    if not self.module_container.is_healthy():
+                        logger.warning("One of subprocesses crashed, restarting the server")
                         break
 
                     if self._should_choose_other_blocks():
@@ -252,8 +252,19 @@ class Server:
         gc.collect()  # In particular, this closes unused file descriptors
 
         cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
-        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
+
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+
+            allocated_vram = torch.cuda.memory_allocated(self.device)
+            reserved_vram = torch.cuda.memory_reserved(self.device)
+            gib = 1024**3
+            logger.info(
+                f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
+                f"{reserved_vram / gib:.1f} GiB reserved memory"
+            )
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
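
For context on the VRAM log line added above (an illustrative snippet, not part of the commit): `torch.cuda.memory_allocated()` counts memory held by live tensors, while `torch.cuda.memory_reserved()` also includes blocks the caching allocator keeps around for reuse, which is why `empty_cache()` is called before logging.

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        gib = 1024**3

        x = torch.empty(256, 1024, 1024, device=device)  # ~1 GiB of float32
        del x  # the tensor is gone, but the allocator keeps its block cached

        print(f"allocated: {torch.cuda.memory_allocated(device) / gib:.1f} GiB")  # ~0.0
        print(f"reserved:  {torch.cuda.memory_reserved(device) / gib:.1f} GiB")   # ~1.0

        torch.cuda.empty_cache()  # hand the cached block back to the driver
        print(f"reserved:  {torch.cuda.memory_reserved(device) / gib:.1f} GiB")   # ~0.0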
@@ -470,9 +481,10 @@ class ModuleContainer(threading.Thread):
         """
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
 
-    @property
-    def handlers_alive(self) -> bool:
-        return all(handler.is_alive() for handler in self.conn_handlers)
+    def is_healthy(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers) and all(
+            pool.is_alive() for pool in self.runtime.pools
+        )
 
     def shutdown(self):
         """
@@ -510,6 +522,10 @@ class ModuleContainer(threading.Thread):
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()
 
+        logger.debug("Shutting down backends")
+        for backend in self.module_backends.values():
+            backend.shutdown()
+
         logger.info("Module container shut down successfully")