@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 
 """
+import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
@@ -27,7 +29,7 @@ class MemoryCache:
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
-        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
@@ -36,6 +38,8 @@ class MemoryCache:
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
 
     @property
     def current_size_bytes(self) -> int:
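
The two primitives above split the job: _lock_acquire_memory serializes allocation attempts across ConnectionHandler processes, while _memory_freed_event lets a blocked allocator sleep until the runtime frees memory. A minimal standalone sketch of that wait/notify pattern (the names free_memory and wait_for_capacity are illustrative, not part of this change):

    import ctypes
    import multiprocessing as mp
    import time

    memory_freed = mp.Event()
    current_size = mp.Value(ctypes.c_int64, 0, lock=False)

    def free_memory(nbytes: int) -> None:
        # called from the process that releases memory: drop usage, wake waiters
        current_size.value -= nbytes
        memory_freed.set()

    def wait_for_capacity(needed: int, max_size: int, timeout: float = 10.0) -> None:
        # called by a would-be allocator: block until the request can fit
        deadline = time.perf_counter() + timeout
        while current_size.value + needed > max_size:
            remaining = deadline - time.perf_counter()
            if not memory_freed.wait(remaining):
                raise TimeoutError(f"no room for {needed} bytes within {timeout}s")
            memory_freed.clear()  # consume the wakeup, then re-check the condition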
@@ -67,27 +71,39 @@ class MemoryCache:
         assert descr.device is None and descr
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
             if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
+                    self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called inside _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            if not self._memory_freed_event.wait(remaining_time):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
 
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
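
Why the run_in_executor indirection above: mp.Event.wait() blocks its whole thread, and allocate_cache runs on the ConnectionHandler's asyncio loop, so the blocking wait is pushed to a worker thread while other coroutines keep serving requests. A standalone sketch of that mechanic (demo code, not taken from this change):

    import asyncio
    import multiprocessing as mp

    event = mp.Event()

    def blocking_wait(timeout: float) -> bool:
        return event.wait(timeout)  # blocks a worker thread, not the event loop

    async def main() -> None:
        loop = asyncio.get_event_loop()
        waiter = loop.run_in_executor(None, blocking_wait, 5.0)
        await asyncio.sleep(0)  # the loop stays free to run other coroutines
        event.set()  # pretend another process just freed some cache
        print(await waiter)  # -> True

    asyncio.run(main())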
@@ -100,7 +116,7 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
-        with self.lock_metadata:
+        with self._lock_metadata:
             if self._allocated_tensors is None:
                 self._allocated_tensors = {}
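
For context, a hypothetical caller-side view of the new behavior, assuming hivemind's TensorDescriptor dataclass accepts a shape and dtype (nothing below appears in the diff): allocation now waits for memory to be freed instead of failing fast with AllocationFailed.

    import torch
    from hivemind.utils import TensorDescriptor

    async def handle_inference_request(cache: "MemoryCache") -> None:
        descr = TensorDescriptor(size=(1024,), dtype=torch.float32)  # device stays None
        async with cache.allocate_cache(descr) as handle:
            # handle is valid here; if the cache was full, allocate_cache waited
            # (in a background thread) for other requests to free memory
            ...
        # on exit, the runtime is told to free the handle and
        # _memory_freed_event wakes any allocation still waiting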