Переглянути джерело

Reset MemoryCache during rebalancings (#154)

Before this PR, if there were open inference sessions right when rebalancing is triggered, their cache was never properly destroyed.
Alexander Borzunov 2 роки тому
батько
коміт
73df69a117
2 змінених файлів з 22 додано та 18 видалено
  1. 2 10
      src/petals/server/memory_cache.py
  2. 20 8
      src/petals/server/server.py

+ 2 - 10
src/petals/server/memory_cache.py

@@ -33,12 +33,10 @@ class MemoryCache:
         self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
-        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
-        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+        self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
-        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
         self._lock_acquire_memory = mp.Lock()
         self._memory_freed_event = mp.Event()
 
@@ -83,14 +81,12 @@ class MemoryCache:
                     allocated_handle = int(self.handle_counter)
                     self.current_size_bytes += allocated_size_bytes
                     self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                    self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
             if allocated_handle is not None:
                 async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
                 self._memory_freed_event.set()
@@ -122,13 +118,9 @@ class MemoryCache:
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
         with self._lock_metadata:
-            if self._allocated_tensors is None:
-                self._allocated_tensors = {}
-
             # read creation/deletion requests from connection handlers
-            for i in range(int(self._pending_messages.value)):
+            while self._pipe_recv.poll():
                 recv_handle, recv_data = self._pipe_recv.recv()
-                self._pending_messages.value -= 1
                 if isinstance(recv_data, TensorDescriptor):
                     self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
                 elif recv_data is None:

+ 20 - 8
src/petals/server/server.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import itertools
 import math
 import multiprocessing as mp
 import random
@@ -72,8 +73,8 @@ class Server:
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         balance_quality: float = 0.75,
-        mean_balance_check_period: float = 60,
-        mean_block_selection_delay: float = 0.5,
+        mean_balance_check_period: float = 120,
+        mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: Optional[bool] = None,
         **kwargs,
@@ -119,7 +120,6 @@ class Server:
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            logger.info("Please check that your server is reachable at http://health.petals.ml")
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
@@ -157,8 +157,8 @@ class Server:
         if attn_cache_size is None:
             # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
             attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
-        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
@@ -211,7 +211,8 @@ class Server:
                 prefix=self.prefix,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
-                memory_cache=self.memory_cache,
+                attn_cache_size=self.attn_cache_size,
+                alloc_timeout=self.alloc_timeout,
                 throughput=self.throughput,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
@@ -310,7 +311,8 @@ class ModuleContainer(threading.Thread):
         prefix: str,
         converted_model_name_or_path: str,
         block_config: BloomConfig,
-        memory_cache: MemoryCache,
+        attn_cache_size: int,
+        alloc_timeout: float,
         throughput: float,
         block_indices: List[int],
         min_batch_size: int,
@@ -339,8 +341,9 @@ class ModuleContainer(threading.Thread):
         joining_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
 
+        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        blocks = {}
         try:
-            blocks = {}
             for module_uid, block_index in zip(module_uids, block_indices):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
@@ -380,6 +383,10 @@ class ModuleContainer(threading.Thread):
                     max_batch_size=max_batch_size,
                 )
         except:
+            logger.debug("Shutting down backends")
+            for backend in blocks.values():
+                backend.shutdown()
+
             joining_announcer.stop.set()
             joining_announcer.join()
             declare_active_modules(
@@ -563,7 +570,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.stop = threading.Event()
 
     def run(self) -> None:
-        while True:
+        for iter_no in itertools.count():
             declare_active_modules(
                 self.dht,
                 self.module_uids,
@@ -571,5 +578,10 @@ class ModuleAnnouncerThread(threading.Thread):
                 state=self.state,
                 throughput=self.throughput,
             )
+            if iter_no == 0 and self.state == ServerState.JOINING:
+                logger.info(
+                    f"Please ensure that your server is reachable. "
+                    f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}"
+                )
             if self.stop.wait(self.update_period):
                 break