
make attention cache wait until memory is freed

justheuristic · 3 years ago · commit c8d22f8bbb
1 changed file with 27 additions and 11 deletions

+ 27 - 11
src/server/cache.py

@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 """
+import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
@@ -36,6 +38,8 @@ class MemoryCache:
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
 
     @property
     def current_size_bytes(self) -> int:
@@ -67,19 +71,17 @@ class MemoryCache:
         assert descr.device is None and descr
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
@@ -88,6 +90,20 @@ class MemoryCache:
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
+                self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called while holding _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
+                f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            if not self._memory_freed_event.wait(deadline - time.perf_counter() if timeout is not None else None):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
 
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
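
For context, here is a minimal, self-contained sketch of the pattern this commit introduces, not the actual petals code: when the cache is full, the allocator blocks on a multiprocessing Event (set whenever memory is freed) instead of raising AllocationFailed immediately, and the blocking wait is pushed into an executor thread so the asyncio event loop stays responsive. The names ToyCache, allocate, free and the demo are hypothetical; the real MemoryCache additionally serializes waiting allocators with an mp.Lock (_lock_acquire_memory) via hivemind.utils.enter_asynchronously and shares its size counter across processes.

import asyncio
import multiprocessing as mp
import time
from typing import Optional


class AllocationFailed(Exception):
    pass


class ToyCache:  # hypothetical stand-in for MemoryCache, single-process for brevity
    def __init__(self, max_size_bytes: int):
        self.max_size_bytes = max_size_bytes
        self.current_size_bytes = 0
        self._memory_freed_event = mp.Event()

    def _wait_until_available(self, size: int, timeout: Optional[float] = None):
        # runs in a worker thread; mirrors MemoryCache._wait_until_available
        if size > self.max_size_bytes:
            raise AllocationFailed(f"{size} bytes can never fit in a {self.max_size_bytes}-byte cache")
        deadline = None if timeout is None else time.perf_counter() + timeout
        while self.current_size_bytes + size > self.max_size_bytes:
            remaining = None if deadline is None else deadline - time.perf_counter()
            if not self._memory_freed_event.wait(remaining):
                raise AllocationFailed(f"could not allocate {size} bytes in {timeout} seconds")
            self._memory_freed_event.clear()  # something was freed; loop around to re-check the size

    async def allocate(self, size: int, timeout: Optional[float] = None):
        loop = asyncio.get_event_loop()
        # offload the blocking Event.wait so the event loop keeps serving other requests
        await loop.run_in_executor(None, self._wait_until_available, size, timeout)
        self.current_size_bytes += size

    def free(self, size: int):
        self.current_size_bytes -= size
        self._memory_freed_event.set()  # wake any allocator stuck in _wait_until_available


async def demo():
    cache = ToyCache(max_size_bytes=100)
    await cache.allocate(80)                                       # fits immediately
    waiter = asyncio.create_task(cache.allocate(50, timeout=5.0))  # 80 + 50 > 100: must wait
    await asyncio.sleep(0.1)
    cache.free(80)  # sets the event; the waiter re-checks the size and proceeds
    await waiter
    assert cache.current_size_bytes == 50


asyncio.run(demo())

In the real MemoryCache the Event lives in shared memory, so a ConnectionHandler freeing its handle in one process wakes allocators waiting in others; the run_in_executor hop is what lets a synchronous mp.Event.wait coexist with asyncio without stalling the loop.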