@@ -17,6 +17,7 @@ import torch
 from hivemind.utils import TensorDescriptor, get_logger

 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes

 logger = get_logger(__name__)

@@ -26,9 +27,8 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -60,11 +60,12 @@ class MemoryCache:
         self._handle_counter.value = value

     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor, timeout: Optional[float] = None) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.

         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit

         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -76,7 +77,7 @@ class MemoryCache:
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
         max_alloc_size = self.get_allocation_size(*descriptors)

         gib = 1024**3
         cur_size, max_size = self.current_size_bytes, self.max_size_bytes
         friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
         logger.info(
@@ -84,24 +85,26 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )

-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
+            logger.info(f"rpc_inference.dealloc-began(size={max_alloc_size / gib:.2f} GiB)")
             await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            logger.info(f"rpc_inference.dealloc-done(size={max_alloc_size / gib:.2f} GiB)")

     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())

-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
         - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -110,7 +113,7 @@ class MemoryCache:
         loop = asyncio.get_event_loop()
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, timeout)
             async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
@@ -124,7 +127,6 @@ class MemoryCache:
         - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         - _schedule_free() must finish freeing memory even in case of cancellation
         """
-
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
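
Usage sketch (not part of this diff): how a caller might pass the new per-call `timeout`. The handler coroutine, the descriptor shape/dtype/device, and the 10-second value are illustrative assumptions, and the import path is assumed from the Petals server package layout; `MemoryCache`, `TensorDescriptor`, `AllocationFailed`, and `allocate_cache(*descriptors, timeout=...)` are the pieces touched by the change above.

```python
import torch
from hivemind.utils import TensorDescriptor

from petals.server.memory_cache import AllocationFailed, MemoryCache  # assumed module path


async def handle_inference_request(cache: MemoryCache) -> None:
    # Illustrative past-KV buffer descriptor (shape, dtype, and device are made up).
    descr = TensorDescriptor(size=(1, 2048, 4096), dtype=torch.float16, device=torch.device("cuda:0"))

    try:
        # The per-call `timeout` replaces the old instance-wide `alloc_timeout`;
        # timeout=None (the default) means no time limit, per the new docstring.
        async with cache.allocate_cache(descr, descr, timeout=10.0) as handles:
            ...  # use the returned handles while the reservation is held
    except AllocationFailed:
        ...  # the cache could not free enough memory in time
```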
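The other functional change is that `get_allocation_size` now calls `get_size_in_bytes(descr.dtype)` instead of `torch.finfo(descr.dtype).bits // 8`, which only works for floating-point dtypes. The actual helper in `petals.utils.misc` is not shown in this diff; a minimal sketch consistent with that behaviour, assuming it also has to handle integer and bool dtypes, might look like this:

```python
import torch


def get_size_in_bytes(dtype: torch.dtype) -> int:
    """Per-element size of `dtype` in bytes (sketch only; the real petals helper may differ)."""
    if dtype == torch.bool:
        return 1  # torch.iinfo does not accept bool
    # torch.finfo covers floating-point dtypes, torch.iinfo covers integer dtypes.
    type_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return type_info.bits // 8


assert get_size_in_bytes(torch.float16) == 2
assert get_size_in_bytes(torch.int8) == 1
```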