Browse Source

Make attention cache wait until memory is freed (#53)

Previously, attempting to allocate memory with a MemoryCache that did not have enough free space would immediately throw AllocationFailed.

PR changes this behavior to the following:
- by default, wait until memory is freed by other tenants (FIFO)
- if could not allocate within timeout, throw AllocationFailed
- if allocated size is too big to fit even in empty cache, throw AllocationFailed

- [x] passes existing tests
- [x] passes manual load tests

P.S. In case anyone wondered: using mp.Condition would not make the code simpler; its lock behavior is slightly different from what we need here.

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 years ago
parent
commit
f3984b192a
1 changed file with 30 additions and 14 deletions
  1. 30 14
      src/server/cache.py

+ 30 - 14
src/server/cache.py

@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 
 """
 """
+import asyncio
 import contextlib
 import contextlib
 import ctypes
 import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
 import os
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 from typing import AsyncContextManager, Dict, Optional, Union
 
 
 import hivemind
 import hivemind
@@ -27,7 +29,7 @@ class MemoryCache:
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
         self.device = device
-        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
@@ -36,6 +38,8 @@ class MemoryCache:
 
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
 
 
     @property
     @property
     def current_size_bytes(self) -> int:
     def current_size_bytes(self) -> int:
@@ -67,27 +71,39 @@ class MemoryCache:
         assert descr.device is None and descr
         assert descr.device is None and descr
         allocated_handle = None
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
 
 
             yield allocated_handle
             yield allocated_handle
         finally:
         finally:
             if allocated_handle is not None:
             if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     self._pending_messages.value += 1
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
                     self.current_size_bytes -= allocated_size_bytes
+                self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called inside _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            if not self._memory_freed_event.wait(remaining_time):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
 
 
     @contextlib.contextmanager
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
     def use_cache(self, handle: Handle) -> torch.Tensor:
@@ -100,7 +116,7 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
 
-        with self.lock_metadata:
+        with self._lock_metadata:
             if self._allocated_tensors is None:
             if self._allocated_tensors is None:
                 self._allocated_tensors = {}
                 self._allocated_tensors = {}