Browse Source

Merge branch 'main' into priority-tasks

justheuristic 2 years ago
parent
commit
1c456f4f46
2 changed files with 31 additions and 14 deletions
  1. 1 0
      src/client/sequence_manager.py
  2. 30 14
      src/server/cache.py

+ 1 - 0
src/client/sequence_manager.py

@@ -141,6 +141,7 @@ class RemoteSequenceManager:
                         stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
                     )
                     self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                    break
                 except Exception as e:
                     retries += 1
                     if retries >= self.max_retries:

+ 30 - 14
src/server/cache.py

@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 """
+import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
@@ -27,7 +29,7 @@ class MemoryCache:
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
-        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
@@ -36,6 +38,8 @@ class MemoryCache:
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
 
     @property
     def current_size_bytes(self) -> int:
@@ -67,27 +71,39 @@ class MemoryCache:
         assert descr.device is None and descr
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
             if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
+                self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called inside _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            if not self._memory_freed_event.wait(remaining_time):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
 
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
@@ -100,7 +116,7 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
-        with self.lock_metadata:
+        with self._lock_metadata:
             if self._allocated_tensors is None:
                 self._allocated_tensors = {}