
Rewrite MemoryCache alloc_timeout logic (#434)

- rpc_inference: the server now accepts an allocation timeout from the user, defaults to no timeout (see the metadata sketch below)
- bugfix: the inference timeout is now measured from the moment the request is received
    - previously, you had to wait for your own timeout plus the time it takes to sort through the queue (other users' timeouts)
    - now, you get AllocationFailed if you had to wait for more than (timeout) seconds, regardless of other users
- a request for inference with no timeout now fails instantly if there is not enough memory available
- dtype size in bytes is now correctly determined for int, bool & other non-float types
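
For illustration, this is roughly the per-request metadata that the server handler now reads; how a particular client fills this dict is outside the hunks below, so treat the values as placeholders:

    # metadata dict as parsed in TransformerConnectionHandler.rpc_inference (see handler.py below)
    metadata = {
        "max_length": 2048,      # illustrative value
        "alloc_timeout": 30.0,   # wait up to 30 s for attention-cache memory; 0.0 (the default) fails instantly
    }
    # the server additionally clamps the requested timeout to its own --max_alloc_timeout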


---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic committed 2 years ago
commit c08d09c4d3
+ 3 - 3
src/petals/cli/run_server.py

@@ -96,9 +96,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=1,
-                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
-                             'before rejecting the request')
+    parser.add_argument('--max_alloc_timeout', type=float, default=600,
+                        help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
+                             " before rejecting the request")
     parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

+ 3 - 3
src/petals/server/backend.py

@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import get_size_in_bytes, is_dummy
 
 logger = get_logger(__name__)
 
@@ -63,7 +63,7 @@ class TransformerBackend(ModuleBackend):
         )
 
         self.dtype = backend_dtype
-        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
+        self.dtype_bytes = get_size_in_bytes(self.dtype)
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
@@ -83,7 +83,7 @@ class TransformerBackend(ModuleBackend):
 
         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""

+ 3 - 2
src/petals/server/block_utils.py

@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes
 
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@@ -37,7 +38,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:
@@ -46,6 +47,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))

+ 11 - 3
src/petals/server/handler.py

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
+                alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
@@ -166,7 +167,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                async with self._allocate_cache(
+                    requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
+                ) as cache_handles:
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
@@ -528,14 +531,19 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        batch_size: int,
+        max_length: int,
+        timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
         :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
         """
         descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
-        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
     def _log_request(

+ 68 - 21
src/petals/server/memory_cache.py

@@ -12,12 +12,13 @@ import os
 import time
 from typing import AsyncContextManager, Dict, Optional, Sequence
 
-import hivemind
+import async_timeout
 import torch
-from hivemind.utils import TensorDescriptor, get_logger
+from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
 
 from petals.data_structures import Handle
 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -25,11 +26,12 @@ logger = get_logger(__name__)
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
+        self.max_alloc_timeout = max_alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
@@ -46,6 +48,14 @@ class MemoryCache:
     def current_size_bytes(self, value: int):
         self._current_size.value = value
 
+    @property
+    def enqueued_size_bytes(self) -> int:
+        return self._enqueued_size.value
+
+    @enqueued_size_bytes.setter
+    def enqueued_size_bytes(self, value: int):
+        self._enqueued_size.value = value
+
     @property
     def bytes_left(self) -> int:
         return self.max_size_bytes - self.current_size_bytes
@@ -59,11 +69,14 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(
+        self, *descriptors: TensorDescriptor, timeout: float
+    ) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
 
         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -73,6 +86,8 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        if self.max_alloc_timeout is not None:
+            timeout = min(timeout, self.max_alloc_timeout)
         max_alloc_size = self.get_allocation_size(*descriptors)
 
         gib = 1024**3
@@ -83,10 +98,10 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )
 
-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
             self._free(max_alloc_size, alloc_task)
@@ -96,28 +111,59 @@ class MemoryCache:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())
 
-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(
+        self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
+    ) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         """
+        try:
+            async with self._wait_for_free_memory(alloc_size, timeout):
+                with self._lock_metadata:
+                    handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
+                    self.current_size_bytes += alloc_size
+                    self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+                    self._pipe_send.send((handles, descriptors))
+                    return handles
+        except TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
 
+    @contextlib.asynccontextmanager
+    async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
+        start_time = time.perf_counter()
         loop = asyncio.get_event_loop()
-        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-            if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            with self._lock_metadata:
-                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
-                self.current_size_bytes += alloc_size
-                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
-                self._pipe_send.send((handles, descriptors))
-                return handles
-
-    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
+
+        self.enqueued_size_bytes += alloc_size
+        allocated = False
+        try:
+            context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
+            # contextlib.AsyncExitStack() is used as a null context here
+            async with context_manager:
+                if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
+                    raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                async with enter_asynchronously(self._lock_acquire_memory):
+                    if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                        if timeout == 0:
+                            raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                        elapsed_time = time.perf_counter() - start_time
+                        remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
+                        await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
+
+                allocated = True
+                self.enqueued_size_bytes -= alloc_size
+                yield
+        except asyncio.TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
+        finally:
+            if not allocated:
+                self.enqueued_size_bytes -= alloc_size
+
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task):
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
@@ -133,9 +179,10 @@ class MemoryCache:
             raise AllocationFailed(
                 f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
             )
+        timeout = timeout if timeout != float("inf") else None
         deadline = None if timeout is None else time.perf_counter() + timeout
         while self.current_size_bytes + allocated_size > self.max_size_bytes:
-            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            remaining_time = None if timeout is None else deadline - time.perf_counter()
             if not self._memory_freed_event.wait(remaining_time):
                 raise AllocationFailed(
                     f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"

+ 8 - 8
src/petals/server/server.py

@@ -31,6 +31,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
@@ -59,12 +60,12 @@ class Server:
         min_batch_size: int = 1,
         max_batch_size: Optional[int] = None,
         max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        max_alloc_timeout: float = 600,
         attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -185,13 +186,14 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.max_alloc_timeout = max_alloc_timeout
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
 
         # For disk cache
         self.cache_dir = cache_dir
@@ -217,8 +219,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
@@ -311,13 +311,13 @@ class Server:
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
-                alloc_timeout=self.alloc_timeout,
                 server_info=self.server_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
                 max_chunk_size_bytes=self.max_chunk_size_bytes,
+                max_alloc_timeout=self.max_alloc_timeout,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
@@ -413,12 +413,12 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
         max_chunk_size_bytes: int,
+        max_alloc_timeout: float,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@@ -434,7 +434,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
@@ -663,7 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.server_info = server_info
         self.memory_cache = memory_cache
 
-        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
         self.bytes_per_token //= block_config.num_key_value_groups
 
         self.update_period = update_period
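
As a rough illustration of the attention cache sizing above (the hidden size, key-value group count, and dtype are made-up numbers, not values read from a real model config):

    import torch

    from petals.utils.misc import get_size_in_bytes

    hidden_size, num_key_value_groups, attn_cache_tokens = 4096, 1, 8192  # illustrative values
    cache_values_per_block = 2 * hidden_size * attn_cache_tokens // num_key_value_groups
    cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(torch.bfloat16)
    # 2 * 4096 * 8192 values at 2 bytes each = 128 MiB of attention cache per served block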

+ 10 - 0
src/petals/utils/misc.py

@@ -9,6 +9,16 @@ def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
 
 
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
+
+
 def docstring_from(source):
     def add_docstring(dest):
         dest.__doc__ = source.__doc__
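
A quick check of the new helper on dtypes that the old torch.finfo(...).bits expression could not handle (torch.finfo raises a TypeError for integer and bool dtypes):

    import torch

    from petals.utils.misc import get_size_in_bytes

    get_size_in_bytes(torch.float16)  # 2 bytes
    get_size_in_bytes(torch.int64)    # 8 bytes; torch.finfo(torch.int64) would raise TypeError
    get_size_in_bytes(torch.bool)     # 1 byte, taken from the SPECIAL_DTYPE_SIZES table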

+ 2 - 1
src/petals/utils/peft.py

@@ -20,6 +20,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -285,5 +286,5 @@ def estimate_adapter_memory_per_block(
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
             )
         adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter

+ 184 - 0
tests/test_cache.py

@@ -0,0 +1,184 @@
+import asyncio
+import multiprocessing as mp
+import random
+import time
+from typing import Optional
+
+import pytest
+import pytest_asyncio  # make sure the module exists; otherwise the test will be skipped
+import torch
+from hivemind import TensorDescriptor
+
+from petals.server.memory_cache import AllocationFailed, MemoryCache
+from petals.utils.misc import get_size_in_bytes
+
+
+def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
+    if dtype is None:
+        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
+    elem_size_bytes = get_size_in_bytes(dtype)
+    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
+    return descr
+
+
+@pytest.mark.asyncio
+async def test_cache_timeout():
+    cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
+    cache.runtime_pid += 1  # pretend we're another process
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
+        pass
+
+    async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+            async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
+                        pass
+                assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
+                async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
+                    pass
+
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0):  # exceeds max timeout
+                        pass
+                assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
+
+            # test memory allocation when another task frees the memory
+            async def _klog_the_cache():
+                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+                    pass
+
+            large_alloc_task = asyncio.create_task(_klog_the_cache())
+
+            t_start = time.perf_counter()
+            await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+            async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):  # exceeds max timeout
+                pass  # this memory should allocate once the background task clears the queue
+            assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
+            with pytest.raises(AllocationFailed):
+                await large_alloc_task
+
+            # test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
+            large_alloc_task = asyncio.create_task(_klog_the_cache())
+            t_start = time.perf_counter()
+            await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+            with pytest.raises(AllocationFailed):
+                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+                    pass  # this memory should allocate once the background task clears the queue
+            assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
+            with pytest.raises(AllocationFailed):
+                await large_alloc_task
+
+
+@pytest.mark.asyncio
+async def test_unlimited_timeout():
+    cache = MemoryCache(max_size_bytes=1024)
+    cache.runtime_pid += 1  # pretend we're another process
+    t_start = time.perf_counter()
+
+    async def _klog_the_cache():
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+            await asyncio.sleep(0.5)
+
+    alloc_task = asyncio.create_task(_klog_the_cache())
+    await asyncio.sleep(0.1)
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
+        await alloc_task
+    assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+    cache = MemoryCache(max_size_bytes=2048)
+    alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
+    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
+    with pytest.raises(AssertionError):
+        async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
+            pass  # fails because cache must be allocated from another process
+
+    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes
+    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes
+    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes
+    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes
+    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
+    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes
+
+    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
+        loop = asyncio.get_event_loop()
+        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
+            pipe_sender.send(handles)
+            await loop.run_in_executor(None, dealloc_event.wait)
+
+    async def _allocate_af():
+        alloc_event.wait()
+        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
+        await allocate_a_task
+        allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache
+        await allocate_f_task
+
+    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1.start()
+
+    async def _allocate_bcde():
+        alloc_event.wait()
+        await asyncio.sleep(0.1)  # ensure that the other tensor is always allocated (and sent through pipe) first
+        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
+        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
+        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
+
+    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2.start()
+    assert cache.current_size_bytes == 0
+    alloc_event.set()
+    (handle_a,) = pipe_receiver.recv()
+
+    handle_b, handle_c, handle_d = pipe_receiver.recv()
+
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tensor_a.dtype == torch.uint8
+        tensor_a[2:5] = torch.tensor((42, 43, 44))
+
+    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
+        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
+        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
+        tensor_a += 1
+        tensor_b[...] = -1.337
+    assert cache.current_size_bytes == 809  # this checks a,b,c,d are allocated but e still awaits memory
+
+    dealloc_bcd_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 768  # only tensor a should be allocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a, handle_b):
+            pass  # one of handles (c) is deallocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_d):
+            pass  # handle_d is deallocated correctly, even though it is never used
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tuple(tensor_a[2:5]) == (43, 44, 45)
+
+    dealloc_a_event.set()
+    (handle_e,) = pipe_receiver.recv()  # e can finally be allocated
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate
+
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a):
+            pass  # tensor a is no longer allocated
+    with cache.use_cache(handle_e) as (tensor_e,):
+        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
+
+    dealloc_e_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1792  # only tensor f is still allocated
+    dealloc_f_event.set()
+
+    alloc_process1.join()
+    alloc_process2.join()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 0
+    assert cache.current_size_bytes == 0
+    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
+    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"