|
@@ -26,8 +26,9 @@ Handle = int
|
|
|
class MemoryCache:
|
|
|
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
|
|
|
|
|
|
- def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
|
|
|
+ def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
|
|
|
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
|
|
+ self.alloc_timeout = alloc_timeout
|
|
|
self.device = device
|
|
|
self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
|
|
|
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
@@ -75,7 +76,7 @@ class MemoryCache:
|
|
|
try:
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
|
|
if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
|
|
|
- await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
|
|
|
+ await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout)
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
allocated_handle = int(self.handle_counter)
|
|
|
self.current_size_bytes += allocated_size_bytes
|