@@ -16,6 +16,8 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
+from petals.utils.asyncio import shield_and_wait
+
 logger = get_logger(__file__)
 
 Handle = int
@@ -66,28 +68,46 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert descr.device is None and descr
-        allocated_handle = None
-        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        loop = asyncio.get_event_loop()
+
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
         try:
-            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(
-                        None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout
-                    )
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    allocated_handle = int(self.handle_counter)
-                    self.current_size_bytes += allocated_size_bytes
-                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                    self._pipe_send.send((allocated_handle, descr))
-
-            yield allocated_handle
+            yield await shield_and_wait(alloc_task)
         finally:
-            if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
-                    self.current_size_bytes -= allocated_size_bytes
-                    self._memory_freed_event.set()
+            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
+
+    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+        """
+
+        loop = asyncio.get_event_loop()
+        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+            if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                handle = int(self.handle_counter)
+                self.current_size_bytes += alloc_size
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handle, descr))
+                return handle
+
+    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+            - _schedule_free() must finish freeing memory even in case of cancellation
+        """
+
+        if alloc_task.exception() is not None:
+            return
+        handle = alloc_task.result()
+
+        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self.current_size_bytes -= alloc_size
+            self._memory_freed_event.set()
 
     def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
         # note: this function should only be called inside _lock_acquire_memory!
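
For reference, shield_and_wait (imported above from petals.utils.asyncio) is expected to behave like asyncio.shield(), except that on cancellation it keeps waiting for the inner task to finish before re-raising CancelledError. That is what the docstrings above rely on: a lock taken inside _schedule_alloc() is always released, and the finally block always sees a completed alloc_task. Below is a minimal sketch of such a helper using only standard asyncio; it is an illustration of the idea, not necessarily the exact implementation in petals/utils/asyncio.py.

import asyncio
from typing import Awaitable, TypeVar

T = TypeVar("T")


async def shield_and_wait(awaitable: Awaitable[T]) -> T:
    # Shield the inner task from outer cancellation, but unlike a bare
    # asyncio.shield(), do not return control until the task has finished.
    inner = asyncio.ensure_future(awaitable)
    cancel_exc = None
    while True:
        try:
            result = await asyncio.shield(inner)
            break
        except asyncio.CancelledError as e:
            if inner.cancelled():
                raise  # the inner task itself was cancelled
            cancel_exc = e  # outer cancellation: remember it, keep waiting
    if cancel_exc is not None:
        raise cancel_exc  # propagate cancellation only after the task is done
    return result

Under this assumption, allocate_cache() can be cancelled at any await point while the allocation and the matching free still run to completion under the cache locks, which is why _schedule_free() can safely inspect alloc_task.exception() and alloc_task.result().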
|