
make attention cache wait until memory is freed

justheuristic 3 years ago
parent commit c8d22f8bbb
1 file changed with 27 additions and 11 deletions
      src/server/cache.py
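
In short: before this change, allocate_cache raised AllocationFailed the moment a request did not fit into max_size_bytes. With this commit, the allocating coroutine instead waits until memory is freed: a new _lock_acquire_memory lock serializes allocators, the blocking wait runs in a background executor thread so the event loop keeps serving other requests, and a concurrent deallocation sets _memory_freed_event to make the waiter re-check the size budget.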


@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 """
+import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
@@ -36,6 +38,8 @@ class MemoryCache:
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
 
     @property
     def current_size_bytes(self) -> int:
@@ -67,19 +71,17 @@ class MemoryCache:
         assert descr.device is None and descr
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
@@ -88,6 +90,20 @@ class MemoryCache:
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
+                self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called while holding _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
+                f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            if not self._memory_freed_event.wait(None if deadline is None else deadline - time.perf_counter()):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
 
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
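
To make the new control flow easier to follow outside the diff, here is a minimal, self-contained sketch of the same wait-until-freed pattern. ToyCache, nbytes, and _reserve_blocking are hypothetical names invented for illustration; the real MemoryCache additionally hands allocation metadata to the runtime process over a pipe and enters its locks via hivemind.utils.enter_asynchronously.

import asyncio
import contextlib
import multiprocessing as mp
import time
from typing import Optional


class AllocationFailed(Exception):
    """Local stand-in for the AllocationFailed error in src/server/cache.py."""


class ToyCache:
    """Hypothetical, simplified MemoryCache: a byte budget guarded by an mp.Event."""

    def __init__(self, max_size_bytes: int):
        self.max_size_bytes = max_size_bytes
        self._size = mp.Value("q", 0)  # current size in bytes, with its own lock
        self._lock_acquire_memory = mp.Lock()  # serializes would-be allocators
        self._memory_freed_event = mp.Event()  # set on free, cleared by waiters

    def _wait_until_available(self, nbytes: int, timeout: Optional[float]) -> None:
        # Runs in a worker thread: block until nbytes fit or the deadline passes.
        if nbytes > self.max_size_bytes:
            raise AllocationFailed(f"{nbytes} bytes can never fit into {self.max_size_bytes}")
        deadline = None if timeout is None else time.perf_counter() + timeout
        while self._size.value + nbytes > self.max_size_bytes:
            remaining = None if deadline is None else deadline - time.perf_counter()
            if not self._memory_freed_event.wait(remaining):
                raise AllocationFailed(f"could not allocate {nbytes} bytes in {timeout} seconds")
            self._memory_freed_event.clear()  # re-check: another waiter may have won the race

    def _reserve_blocking(self, nbytes: int, timeout: Optional[float]) -> None:
        with self._lock_acquire_memory:  # only one allocator waits/reserves at a time
            self._wait_until_available(nbytes, timeout)
            with self._size.get_lock():
                self._size.value += nbytes

    @contextlib.asynccontextmanager
    async def allocate(self, nbytes: int, timeout: Optional[float] = None):
        # Offload the blocking reserve to the default executor, mirroring the
        # commit's loop.run_in_executor(None, self._wait_until_available, ...):
        # the event loop keeps serving other requests while this one waits.
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, self._reserve_blocking, nbytes, timeout)
        try:
            yield
        finally:
            with self._size.get_lock():
                self._size.value -= nbytes
            self._memory_freed_event.set()  # wake any allocator blocked above

For instance, with a 100-byte budget, a second allocation of 50 bytes blocks until the first allocation of 80 bytes is released:

async def main():
    cache = ToyCache(max_size_bytes=100)

    async def holder():
        async with cache.allocate(80):
            await asyncio.sleep(0.5)  # hold most of the budget for a while

    async def waiter():
        await asyncio.sleep(0.1)
        async with cache.allocate(50, timeout=5.0):  # blocks until holder() exits
            print("allocated after waiting for memory to be freed")

    await asyncio.gather(holder(), waiter())


if __name__ == "__main__":
    asyncio.run(main())

Clearing the event after every wakeup forces each waiter to re-check the budget, which keeps the loop correct when several waiters race for the same freed bytes; this is the same reason the commit pairs _memory_freed_event.set() on deallocation with .clear() inside the wait loop.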