|
@@ -90,7 +90,7 @@ class MemoryCache:
|
|
|
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
|
|
|
yield handles
|
|
|
finally:
|
|
|
- await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
|
|
|
+ self._free(max_alloc_size, alloc_task)
|
|
|
|
|
|
@staticmethod
|
|
|
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
|
|
@@ -111,25 +111,19 @@ class MemoryCache:
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
|
|
if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
|
|
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
|
|
- async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
+ with self._lock_metadata:
|
|
|
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
|
|
self.current_size_bytes += alloc_size
|
|
|
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
|
|
self._pipe_send.send((handles, descriptors))
|
|
|
return handles
|
|
|
|
|
|
- async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
|
|
|
- """
|
|
|
- This method should be called inside asyncio.shield() because:
|
|
|
- - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
|
|
- - _schedule_free() must finish freeing memory even in case of cancellation
|
|
|
- """
|
|
|
-
|
|
|
+ def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
|
|
|
if alloc_task.exception() is not None:
|
|
|
return
|
|
|
handles = alloc_task.result()
|
|
|
|
|
|
- async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
+ with self._lock_metadata:
|
|
|
self._pipe_send.send((handles, None)) # signal runtime to free these handles
|
|
|
self.current_size_bytes -= alloc_size
|
|
|
self._memory_freed_event.set()
|
|
@@ -160,22 +154,21 @@ class MemoryCache:
|
|
|
assert os.getpid() == self.runtime_pid
|
|
|
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
|
|
|
|
|
- with self._lock_metadata:
|
|
|
- # read creation/deletion requests from connection handlers
|
|
|
- while self._pipe_recv.poll():
|
|
|
- recv_handles, recv_data = self._pipe_recv.recv()
|
|
|
- if recv_data is not None: # create new tensors
|
|
|
- assert len(recv_handles) == len(recv_data)
|
|
|
- for handle, descr in zip(recv_handles, recv_data):
|
|
|
- self._allocated_tensors[handle] = descr.make_zeros()
|
|
|
- assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
|
|
- else: # delete tensors by handle
|
|
|
- for handle in recv_handles:
|
|
|
- if handle not in self._allocated_tensors:
|
|
|
- logger.warning(
|
|
|
- f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
|
|
|
- )
|
|
|
- self._allocated_tensors.pop(handle, None)
|
|
|
+ # read creation/deletion requests from connection handlers
|
|
|
+ while self._pipe_recv.poll():
|
|
|
+ recv_handles, recv_data = self._pipe_recv.recv()
|
|
|
+ if recv_data is not None: # create new tensors
|
|
|
+ assert len(recv_handles) == len(recv_data)
|
|
|
+ for handle, descr in zip(recv_handles, recv_data):
|
|
|
+ self._allocated_tensors[handle] = descr.make_zeros()
|
|
|
+ assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
|
|
+ else: # delete tensors by handle
|
|
|
+ for handle in recv_handles:
|
|
|
+ if handle not in self._allocated_tensors:
|
|
|
+ logger.warning(
|
|
|
+ f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
|
|
|
+ )
|
|
|
+ self._allocated_tensors.pop(handle, None)
|
|
|
yield tuple(self._allocated_tensors[handle] for handle in handles)
|
|
|
|
|
|
|