Explorar el Código

memory cache for attention KVs

justheuristic hace 3 años
padre
commit
e2e9d0e94c
Se han modificado 1 ficheros con 113 adiciones y 27 borrados
  1. 113 27
      src/node/cache.py

+ 113 - 27
src/node/cache.py

@@ -1,43 +1,129 @@
+"""
+A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
+
+For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
+
+TODO In future, one could modify cache to implement, among other things,
+- in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
+- allocate cache as one contigous buffer to avoid fragmentation
+- quantize cached values using bitsandbytes
+- LRU offloading from gpu to ram
+
+"""
 import contextlib
 import ctypes
 import multiprocessing as mp
-from typing import Dict, Tuple
+import os
+from typing import Dict, Optional, Union
 
+import hivemind
 import torch
+from hivemind.utils import TensorDescriptor, get_logger
+
+logger = get_logger(__file__)
+
+Handle = int
 
 
 class MemoryCache:
-    lock: mp.Lock
-    runtime_pid: int
-    handle_counter: mp.Value[ctypes.c_uint64]
-    current_size: mp.Value[ctypes.c_uint64]
-    _runtime_data: Dict[int, SomeKindOfTupleWithTensors]  # workaround for now, while we are on CPU
+    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
+
+    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
+        self.device = device
+        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
+        self._handle_counter = mp.Value(ctypes.c_uint64, 0, lock=False)
+        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
+        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+        self.runtime_pid = os.getpid()
+
+        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
+        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+
+    @property
+    def current_size_bytes(self) -> int:
+        return self._current_size.value
+
+    @current_size_bytes.setter
+    def current_size_bytes(self, value: int):
+        self._current_size.value = value
+
+    @property
+    def handle_counter(self) -> int:
+        return self._handle_counter.value
+
+    @handle_counter.setter
+    def handle_counter(self, value: int):
+        self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> Optional[int]:
+    async def allocate_cache(self, descr: TensorDescriptor) -> Handle:
         """
-        Allocate buffers for attention cache on the compute device, return a unique handle;
-        This function should be called by connection handler processes, may be called concurrently
+        Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
+
+        :param descr: allocate a tensor of this size, dtype, etc
+
+        :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
+        Furthermore, it can be called concurrently with at most one use_cache call in runtime.
         """
-        assert os.getpid() != self.runtime_pid
+        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
+        assert descr.device is None and descr
+        allocated_handle = None
+        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
         try:
-            async with acquire_asynchronously(self.lock):
-                check_and_update_size(current_size, size, dtype)
-                if enough_space:
-                    self.handle_counter.value += 1
-                    handle = int(self.handle_counter.value)
-                    # note: you cannot allocate data here because this is
-                    TODO_SOMEHOW_COMUNICATE_WITH_RUNTIME_TO_CREATE_THE_RIGHT_DATA
-            yield handle
+            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+                    raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
+                                           f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated.")
+
+                allocated_handle = int(self.handle_counter)
+                self.current_size_bytes += allocated_size_bytes
+                self.handle_counter += 1   # note: this will eventually overflow and it is okay
+                self._pending_messages.value += 1
+                self._pipe_send.send((allocated_handle, descr))
+
+            yield allocated_handle
         finally:
-            todo_deallocate(self, handle)
-            # ^-- this should NOT move any data. But it may mark data for movement during next allocation
-            self.data.pop(handle, None);
+            if allocated_handle is not None:
+                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
+                    self.current_size_bytes -= allocated_size_bytes
 
-    def use_cache(self, handle: int) -> Tuple[mp.Value, torch.Tensor, torch.Tensor]:
-        """Return a previously allocated cache, called by ExpertBackend in runtime (a single process)"""
+    @contextlib.contextmanager
+    def use_cache(self, handle: Handle) -> torch.Tensor:
+        """
+        Return a tensor that was previously allocated with try_allocate_cache,
+
+        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
+        """
         assert os.getpid() == self.runtime_pid
-        with self.lock:
-            if first_time:
-                allocate_stuff(self._runtime_data)
-            yield self.data[handle]
+        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
+
+        with self.lock_metadata:
+            if self._allocated_tensors is None:
+                self._allocated_tensors = {}
+
+            # read creation/deletion requests from connection handlers
+            for i in range(int(self._pending_messages.value)):
+                recv_handle, recv_data = self._pipe_recv.recv()
+                self._pending_messages.value -= 1
+                if isinstance(recv_data, TensorDescriptor):
+                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
+                elif recv_data is None:
+                    if recv_handle not in self._allocated_tensors:
+                        logger.warning(
+                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
+                        )
+                    self._allocated_tensors.pop(recv_handle, None)
+                else:
+                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
+
+        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+        yield self._allocated_tensors[handle]
+
+
+class AllocationFailed(Exception):
+    pass