|
@@ -12,12 +12,13 @@ import os
|
|
import time
|
|
import time
|
|
from typing import AsyncContextManager, Dict, Optional, Sequence
|
|
from typing import AsyncContextManager, Dict, Optional, Sequence
|
|
|
|
|
|
-import hivemind
|
|
|
|
|
|
+import async_timeout
|
|
import torch
|
|
import torch
|
|
-from hivemind.utils import TensorDescriptor, get_logger
|
|
|
|
|
|
+from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
|
|
|
|
|
|
from petals.data_structures import Handle
|
|
from petals.data_structures import Handle
|
|
from petals.utils.asyncio import shield_and_wait
|
|
from petals.utils.asyncio import shield_and_wait
|
|
|
|
+from petals.utils.misc import get_size_in_bytes
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@@ -25,11 +26,12 @@ logger = get_logger(__name__)
|
|
class MemoryCache:
|
|
class MemoryCache:
|
|
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
|
|
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
|
|
|
|
|
|
- def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
|
|
|
|
|
|
+ def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
|
|
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
|
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
|
- self.alloc_timeout = alloc_timeout
|
|
|
|
|
|
+ self.max_alloc_timeout = max_alloc_timeout
|
|
self._lock_metadata = mp.Lock()
|
|
self._lock_metadata = mp.Lock()
|
|
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
|
|
+ self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
|
|
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
|
|
self.runtime_pid = os.getpid()
|
|
self.runtime_pid = os.getpid()
|
|
@@ -46,6 +48,14 @@ class MemoryCache:
|
|
def current_size_bytes(self, value: int):
|
|
def current_size_bytes(self, value: int):
|
|
self._current_size.value = value
|
|
self._current_size.value = value
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def enqueued_size_bytes(self) -> int:
|
|
|
|
+ return self._enqueued_size.value
|
|
|
|
+
|
|
|
|
+ @enqueued_size_bytes.setter
|
|
|
|
+ def enqueued_size_bytes(self, value: int):
|
|
|
|
+ self._enqueued_size.value = value
|
|
|
|
+
|
|
@property
|
|
@property
|
|
def bytes_left(self) -> int:
|
|
def bytes_left(self) -> int:
|
|
return self.max_size_bytes - self.current_size_bytes
|
|
return self.max_size_bytes - self.current_size_bytes
|
|
@@ -59,11 +69,14 @@ class MemoryCache:
|
|
self._handle_counter.value = value
|
|
self._handle_counter.value = value
|
|
|
|
|
|
@contextlib.asynccontextmanager
|
|
@contextlib.asynccontextmanager
|
|
- async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
|
|
|
|
|
|
+ async def allocate_cache(
|
|
|
|
+ self, *descriptors: TensorDescriptor, timeout: float
|
|
|
|
+ ) -> AsyncContextManager[Sequence[Handle]]:
|
|
"""
|
|
"""
|
|
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
|
|
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
|
|
|
|
|
|
:param descriptors: one or more tensors tensor of this size, dtype, etc
|
|
:param descriptors: one or more tensors tensor of this size, dtype, etc
|
|
|
|
+ :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
|
|
|
|
|
|
:note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
|
|
:note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
|
|
if not, it will count maximum tensor allocation across devices for the purposes of size limit
|
|
if not, it will count maximum tensor allocation across devices for the purposes of size limit
|
|
@@ -73,6 +86,8 @@ class MemoryCache:
|
|
"""
|
|
"""
|
|
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
|
|
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
|
|
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
|
|
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
|
|
|
|
+ if self.max_alloc_timeout is not None:
|
|
|
|
+ timeout = min(timeout, self.max_alloc_timeout)
|
|
max_alloc_size = self.get_allocation_size(*descriptors)
|
|
max_alloc_size = self.get_allocation_size(*descriptors)
|
|
|
|
|
|
gib = 1024**3
|
|
gib = 1024**3
|
|
@@ -83,10 +98,10 @@ class MemoryCache:
|
|
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
|
|
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
|
|
)
|
|
)
|
|
|
|
|
|
- alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
|
|
|
|
|
|
+ alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
|
|
try:
|
|
try:
|
|
handles = await shield_and_wait(alloc_task)
|
|
handles = await shield_and_wait(alloc_task)
|
|
- logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
|
|
|
|
|
|
+ logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
|
|
yield handles
|
|
yield handles
|
|
finally:
|
|
finally:
|
|
self._free(max_alloc_size, alloc_task)
|
|
self._free(max_alloc_size, alloc_task)
|
|
@@ -96,28 +111,59 @@ class MemoryCache:
|
|
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
|
|
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
|
|
alloc_size_by_device = {}
|
|
alloc_size_by_device = {}
|
|
for descr in descriptors:
|
|
for descr in descriptors:
|
|
- tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
|
|
|
|
|
|
+ tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
|
|
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
|
|
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
|
|
return max(alloc_size_by_device.values())
|
|
return max(alloc_size_by_device.values())
|
|
|
|
|
|
- async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
|
|
|
|
|
|
+ async def _schedule_alloc(
|
|
|
|
+ self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
|
|
|
|
+ ) -> Sequence[Handle]:
|
|
"""
|
|
"""
|
|
This method should be called inside asyncio.shield() because:
|
|
This method should be called inside asyncio.shield() because:
|
|
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
|
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
|
"""
|
|
"""
|
|
|
|
+ try:
|
|
|
|
+ async with self._wait_for_free_memory(alloc_size, timeout):
|
|
|
|
+ with self._lock_metadata:
|
|
|
|
+ handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
|
|
|
+ self.current_size_bytes += alloc_size
|
|
|
|
+ self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
|
|
|
+ self._pipe_send.send((handles, descriptors))
|
|
|
|
+ return handles
|
|
|
|
+ except TimeoutError:
|
|
|
|
+ raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
|
|
|
|
|
|
|
|
+ @contextlib.asynccontextmanager
|
|
|
|
+ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
|
|
|
|
+ start_time = time.perf_counter()
|
|
loop = asyncio.get_event_loop()
|
|
loop = asyncio.get_event_loop()
|
|
- async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
|
|
|
- if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
|
|
|
- await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
|
|
|
- with self._lock_metadata:
|
|
|
|
- handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
|
|
|
- self.current_size_bytes += alloc_size
|
|
|
|
- self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
|
|
|
- self._pipe_send.send((handles, descriptors))
|
|
|
|
- return handles
|
|
|
|
-
|
|
|
|
- def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
|
|
|
|
|
|
+
|
|
|
|
+ self.enqueued_size_bytes += alloc_size
|
|
|
|
+ allocated = False
|
|
|
|
+ try:
|
|
|
|
+ context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
|
|
|
|
+ # contextlib.AsyncExitStack() is used as a null context here
|
|
|
|
+ async with context_manager:
|
|
|
|
+ if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
|
|
|
|
+ raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
|
|
|
|
+ async with enter_asynchronously(self._lock_acquire_memory):
|
|
|
|
+ if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
|
|
|
+ if timeout == 0:
|
|
|
|
+ raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
|
|
|
|
+ elapsed_time = time.perf_counter() - start_time
|
|
|
|
+ remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
|
|
|
|
+ await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
|
|
|
|
+
|
|
|
|
+ allocated = True
|
|
|
|
+ self.enqueued_size_bytes -= alloc_size
|
|
|
|
+ yield
|
|
|
|
+ except asyncio.TimeoutError:
|
|
|
|
+ raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
|
|
|
|
+ finally:
|
|
|
|
+ if not allocated:
|
|
|
|
+ self.enqueued_size_bytes -= alloc_size
|
|
|
|
+
|
|
|
|
+ def _free(self, alloc_size: int, alloc_task: asyncio.Task):
|
|
if alloc_task.exception() is not None:
|
|
if alloc_task.exception() is not None:
|
|
return
|
|
return
|
|
handles = alloc_task.result()
|
|
handles = alloc_task.result()
|
|
@@ -133,9 +179,10 @@ class MemoryCache:
|
|
raise AllocationFailed(
|
|
raise AllocationFailed(
|
|
f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
|
|
f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
|
|
)
|
|
)
|
|
|
|
+ timeout = timeout if timeout != float("inf") else None
|
|
deadline = None if timeout is None else time.perf_counter() + timeout
|
|
deadline = None if timeout is None else time.perf_counter() + timeout
|
|
while self.current_size_bytes + allocated_size > self.max_size_bytes:
|
|
while self.current_size_bytes + allocated_size > self.max_size_bytes:
|
|
- remaining_time = deadline - time.perf_counter() if timeout is not None else None
|
|
|
|
|
|
+ remaining_time = None if timeout is None else deadline - time.perf_counter()
|
|
if not self._memory_freed_event.wait(remaining_time):
|
|
if not self._memory_freed_event.wait(remaining_time):
|
|
raise AllocationFailed(
|
|
raise AllocationFailed(
|
|
f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
|
|
f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
|