瀏覽代碼

the (still) reasonable version

Your Name 2 年之前
父節點
當前提交
cc67c332a6

+ 2 - 2
src/petals/server/backend.py

@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import is_dummy, get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -74,7 +74,7 @@ class TransformerBackend(ModuleBackend):
 
         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""

+ 3 - 2
src/petals/server/block_utils.py

@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes
 
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@@ -36,7 +37,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:
@@ -45,6 +46,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))

+ 3 - 2
src/petals/server/handler.py

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
+                alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -167,7 +168,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
                 prefix_length = 0
 
-                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+                async with self._allocate_cache(requested_backends, batch_size, max_length, alloc_timeout) as cache_handles:
                     assert len(cache_handles) == len(requested_backends)
                     first_request = request
                     background_tasks = set()
@@ -567,7 +568,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int, timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle

+ 12 - 10
src/petals/server/memory_cache.py

@@ -17,6 +17,7 @@ import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -26,9 +27,8 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -60,11 +60,12 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor, timeout: Optional[float] = None) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
 
         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -76,7 +77,7 @@ class MemoryCache:
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
         max_alloc_size = self.get_allocation_size(*descriptors)
 
-        gib = 1024**3
+        gib = 1
         cur_size, max_size = self.current_size_bytes, self.max_size_bytes
         friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
         logger.info(
@@ -84,24 +85,26 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )
 
-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
+            logger.info(f"rpc_inference.dealloc-began(size={max_alloc_size / gib:.2f} GiB)")
             await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            logger.info(f"rpc_inference.dealloc-done(size={max_alloc_size / gib:.2f} GiB)")
 
     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())
 
-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -110,7 +113,7 @@ class MemoryCache:
         loop = asyncio.get_event_loop()
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, timeout)
             async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
@@ -124,7 +127,6 @@ class MemoryCache:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
             - _schedule_free() must finish freeing memory even in case of cancellation
         """
-
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()

+ 3 - 7
src/petals/server/server.py

@@ -31,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
@@ -63,7 +64,6 @@ class Server:
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -189,7 +189,7 @@ class Server:
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
 
         # For disk cache
         self.cache_dir = cache_dir
@@ -213,8 +213,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
@@ -306,7 +304,6 @@ class Server:
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
-                alloc_timeout=self.alloc_timeout,
                 server_info=self.server_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
@@ -407,7 +404,6 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
@@ -427,7 +423,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(

+ 10 - 0
src/petals/utils/misc.py

@@ -5,3 +5,13 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par
 
 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
+
+
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.int8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8

+ 2 - 1
src/petals/utils/peft.py

@@ -19,6 +19,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -284,5 +285,5 @@ def estimate_adapter_memory_per_block(
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
             )
         adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter

+ 132 - 0
tests/test_cache.py

@@ -0,0 +1,132 @@
+import random
+from typing import Optional
+
+import pytest
+import torch
+from hivemind import TensorDescriptor
+
+from petals.server.memory_cache import MemoryCache, AllocationFailed
+import asyncio
+from petals.utils.misc import get_size_in_bytes
+import multiprocessing as mp
+import pytest_asyncio  # make sure the module exists; otherwise the test will be skipped
+
+
+def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
+    if dtype is None:
+        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
+    elem_size_bytes = get_size_in_bytes(dtype)
+    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
+    return descr
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+    cache = MemoryCache(max_size_bytes=2048)
+    alloc_event, dealloc_e_event, dealloc_bcd_event, dealloc_a_event = mp.Event(), mp.Event(), mp.Event(), mp.Event()
+    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
+    with pytest.raises(AssertionError):
+        async with cache.allocate_cache(_make_tensor_descriptor(123)):
+            pass  # fails because cache must be allocated from another process
+
+    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))         # 768 bytes
+    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))        # 8 bytes
+    descr_c = TensorDescriptor.from_tensor(torch.empty((33, ), dtype=torch.bool))       # 33 bytes
+    descr_d = TensorDescriptor.from_tensor(torch.empty((0, ), dtype=torch.int64))       # 0 bytes
+    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
+    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))     # 1792 bytes
+    descr_g = TensorDescriptor.from_tensor(torch.empty((1793,), dtype=torch.uint8))  # 1792 bytes
+
+    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
+        loop = asyncio.get_event_loop()
+        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
+            pipe_sender.send(handles)
+            await loop.run_in_executor(None, dealloc_event.wait)
+
+    async def _allocate_af():
+        alloc_event.wait()
+        print("BEGAN AF")
+        try:
+            async with cache.allocate_cache(descr_g):
+                allocate_f_task = asyncio.create_task(_allocate_and_wait(mp.Event(), descr_f))  # klogs the cache
+                print("CANCELLED")
+                raise asyncio.CancelledError()
+        except asyncio.CancelledError:
+            pass
+        allocate_f_task.cancel()  # unklog the cache
+
+        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
+        await allocate_a_task
+
+    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1.start()
+
+    async def _allocate_bcde():
+        await asyncio.sleep(0.2)  # ensure that the other tensor is always allocated (and sent through pipe) first
+        print("BEGAN BCDE")
+        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
+        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
+        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
+
+    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2.start()
+    assert cache.current_size_bytes == 0
+    alloc_event.set()
+    handle_a, = pipe_receiver.recv()
+
+    handle_b, handle_c, handle_d = pipe_receiver.recv()
+
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tensor_a.dtype == torch.uint8
+        tensor_a[2:5] = torch.tensor((42, 43, 44))
+
+    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
+        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
+        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
+        tensor_a += 1
+        tensor_b[...] = -1.337
+    assert cache.current_size_bytes == 809  # this checks a,b,c,d are allocated but b still awaits memory
+
+    dealloc_bcd_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 768  # only tensor a is allocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a, handle_b):
+            pass  # one of handles (c) is deallocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_d):
+            pass  # handle_e is deallocated, even though it is never used
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tuple(tensor_a[2:5]) == (43, 44, 45)
+
+    dealloc_a_event.set()
+    handle_e, = pipe_receiver.recv()  # e can finally be allocated
+    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate
+
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a):
+            pass  # tensor a is no longer allocated
+    with cache.use_cache(handle_e) as (tensor_e,):
+        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
+
+    dealloc_e_event.set()
+    alloc_process1.join(1)
+    alloc_process2.join(1)
+    assert cache.current_size_bytes == 0
+    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
+    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
+
+    # cache.runtime_pid += 1  # pretend we're another process
+    # async with cache.allocate_cache(_make_tensor_descriptor(768)) as a:
+    #     pass
+    #
+    #
+    # async with cache.allocate_cache(_make_tensor_descriptor(768)):
+    #     async with cache.allocate_cache(_make_tensor_descriptor(1024)):
+    #         async with cache.allocate_cache(_make_tensor_descriptor(512), _make_tensor_descriptor(64)):
+    #             async with cache.allocate_cache(_make_tensor_descriptor(1536)):
+    #                 with pytest.raises(TimeoutError):
+    #                     async with cache.allocate_cache(_make_tensor_descriptor(256), ):
+    #                         pass
+    #                 async with cache.allocate_cache(_make_tensor_descriptor(192)):
+    #                     pass