
Add cache allocation timeout

Aleksandr Borzunov, 2 years ago
parent
commit
8af3ac3623
3 changed files with 8 additions and 3 deletions
  1. +3 -0  cli/run_server.py
  2. +3 -2  src/server/cache.py
  3. +2 -1  src/server/server.py

+ 3 - 0
cli/run_server.py

@@ -57,6 +57,9 @@ def main():
     parser.add_argument('--attn_cache_size', type=str, default=None,
                         help='The size of GPU memory allocated for storing past attention keys/values between inference'
                              ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
+    parser.add_argument('--alloc_timeout', type=float, default=60,
+                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
+                             'before rejecting the request')
     parser.add_argument('--revision', type=str, default='main',
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              " and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

+ 3 - 2
src/server/cache.py

@@ -26,8 +26,9 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
+    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
+        self.alloc_timeout = alloc_timeout
         self.device = device
         self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -75,7 +76,7 @@ class MemoryCache:
         try:
             async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout)
                 async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     allocated_handle = int(self.handle_counter)
                     self.current_size_bytes += allocated_size_bytes
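The body of `_wait_until_available` is not shown in this diff. A minimal sketch of how it could honor the new `timeout` argument, assuming it blocks on the `size_decreased_event` created in `__init__` above (the `AllocationFailed` name and the exact loop structure are assumptions, not the repo's actual code):

```python
import time

class AllocationFailed(Exception):
    """Hypothetical: raised when memory does not free up within alloc_timeout."""

def _wait_until_available(self, allocated_size_bytes: int, timeout: float) -> None:
    # Block until enough cache memory is freed, or give up after `timeout` seconds.
    # Assumes size_decreased_event (the mp.Event from __init__) is set whenever a
    # previously allocated tensor is released.
    if allocated_size_bytes > self.max_size_bytes:
        raise AllocationFailed(f"{allocated_size_bytes} bytes can never fit in the cache")
    deadline = time.perf_counter() + timeout
    while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
        remaining = deadline - time.perf_counter()
        if remaining <= 0 or not self.size_decreased_event.wait(remaining):
            raise AllocationFailed(f"could not allocate {allocated_size_bytes} bytes in {timeout} s")
        self.size_decreased_event.clear()
```

Running this wait through `loop.run_in_executor` (see the hunk above) keeps the blocking `Event.wait` off the server's asyncio event loop.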

+ 2 - 1
src/server/server.py

@@ -55,6 +55,7 @@ class Server:
         revision: str = "main",
         cache_dir: Optional[str] = None,
         attn_cache_size: Optional[int] = None,
+        alloc_timeout: float = 60,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -110,7 +111,7 @@ class Server:
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
 
-        self.memory_cache = MemoryCache(device, attn_cache_size)
+        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
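A quick smoke test of the new wiring, as a sketch (assumes the repo root is importable; the module path simply mirrors the file path in this diff):

```python
# Hedged example: exercise the extended MemoryCache signature directly.
from src.server.cache import MemoryCache

cache = MemoryCache(device="cpu", max_size_bytes=1 << 30, alloc_timeout=60.0)
assert cache.alloc_timeout == 60.0      # stored verbatim by __init__ (see hunk above)
assert cache.max_size_bytes == 1 << 30
```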