
Use number of tokens for attn_cache_size (#286)

* Use number of tokens for attn_cache_size

* Fix cache_bytes_per_block

* Rename attn_cache_size to attn_cache_tokens
Max Ryabinin, 2 years ago
Commit 5c0733711a
4 changed files with 18 additions and 32 deletions
  1. .github/workflows/run-tests.yaml (+1 −1)
  2. src/petals/cli/run_server.py (+5 −15)
  3. src/petals/server/backend.py (+0 −1)
  4. src/petals/server/server.py (+12 −15)

+ 1 - 1
.github/workflows/run-tests.yaml

@@ -86,7 +86,7 @@ jobs:
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log &
+            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
           SERVER1_PID=$!
 
           sleep 5  # wait for the first server to initialize DHT
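
A rough sanity check on the new CI value, assuming the test model is a small BLOOM with hidden_size = 1024 (an assumption; the workflow only exposes $MODEL_NAME) and float32 as set by --torch_dtype float32:

    # Hypothetical back-of-the-envelope check (not part of the commit).
    cache_bytes_per_block = 2 * 1024 * 2048 * 4   # keys + values for 2048 tokens ≈ 16 MiB
    total = cache_bytes_per_block * 12            # blocks 0:12 ≈ 192 MiB, close to the old 0.2GiB flag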

+ 5 - 15
src/petals/cli/run_server.py

@@ -7,7 +7,7 @@ from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
 from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.server.server import Server
+from petals.server.server import DTYPE_MAP, Server
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)
@@ -78,14 +78,12 @@ def main():
 
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all blocks will use this device in torch notation; default: cuda if available else cpu')
-    parser.add_argument("--torch_dtype", type=str, default="auto",
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_size', type=str, default=None,
-                        help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
-                             'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
-                             'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
-                             'The latter is the hidden size of the bigscience/bloom-petals model.')
+    parser.add_argument('--attn_cache_tokens', type=int, default=8192,
+                        help='The number of past attention key/value pairs that will be stored between inference steps. '
+                             'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
@@ -178,13 +176,6 @@ def main():
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
 
-    attn_cache_size = args.pop("attn_cache_size")
-    if attn_cache_size is not None:
-        attn_cache_size = parse_size(attn_cache_size)
-    assert isinstance(
-        attn_cache_size, (int, type(None))
-    ), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
-
     max_disk_space = args.pop("max_disk_space")
     if max_disk_space is not None:
         max_disk_space = parse_size(max_disk_space)
@@ -207,7 +198,6 @@ def main():
         announce_maddrs=announce_maddrs,
         compression=compression,
         max_disk_space=max_disk_space,
-        attn_cache_size=attn_cache_size,
     )
     try:
         server.run()
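
With the parse_size path removed, --torch_dtype is now validated by argparse via choices, and --attn_cache_tokens is a plain integer. A minimal sketch of the resulting CLI behavior, assuming DTYPE_MAP has the usual shape (the real mapping is imported from petals.server.server):

    import argparse
    import torch

    # Assumed contents of DTYPE_MAP; the actual dict lives in petals.server.server.
    DTYPE_MAP = {"auto": "auto", "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}

    parser = argparse.ArgumentParser()
    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto")
    parser.add_argument("--attn_cache_tokens", type=int, default=8192)

    args = parser.parse_args(["--torch_dtype", "bfloat16", "--attn_cache_tokens", "4096"])
    print(DTYPE_MAP[args.torch_dtype], args.attn_cache_tokens)  # -> torch.bfloat16 4096
    # An unknown dtype (e.g. --torch_dtype fp16) is now rejected at parse time
    # instead of reaching the server's assert.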

+ 0 - 1
src/petals/server/backend.py

@@ -48,7 +48,6 @@ class TransformerBackend(ModuleBackend):
             self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
-        assert backend_dtype is not None
         self.dtype = backend_dtype
         self.shard_num_heads = []
         for shard in self.module.module_shards:

+ 12 - 15
src/petals/server/server.py

@@ -56,7 +56,7 @@ class Server:
         revision: str = "main",
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_size: Optional[int] = None,
+        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 60,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
@@ -148,9 +148,7 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device
 
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        torch_dtype = DTYPE_MAP[torch_dtype]
         self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
 
         if tensor_parallel_devices is None:
@@ -165,6 +163,9 @@ class Server:
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
 
+        max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
+        self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
@@ -179,13 +180,10 @@ class Server:
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
 
         gib = 1024**3
-        if attn_cache_size is None:
-            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
-            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
-
-        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
-        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
+        logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
+        self.alloc_timeout = alloc_timeout
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
         self.cache_dir = cache_dir
@@ -236,10 +234,9 @@ class Server:
 
         # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
         autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward
 
-        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
+        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         logger.info(
@@ -256,7 +253,7 @@ class Server:
                 prefix=self.prefix,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
-                attn_cache_size=self.attn_cache_size,
+                attn_cache_bytes=self.attn_cache_bytes,
                 alloc_timeout=self.alloc_timeout,
                 throughput=self.throughput,
                 block_indices=block_indices,
@@ -356,7 +353,7 @@ class ModuleContainer(threading.Thread):
         prefix: str,
         converted_model_name_or_path: str,
         block_config: BloomConfig,
-        attn_cache_size: int,
+        attn_cache_bytes: int,
         alloc_timeout: float,
         throughput: float,
         block_indices: List[int],
@@ -390,7 +387,7 @@ class ModuleContainer(threading.Thread):
 
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
 
-        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
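
The per-block cache size follows directly from the formula introduced in server.py above. A worked example, assuming the bigscience/bloom-petals config mentioned in the removed default's comment (hidden_size = 14336) and bfloat16 block weights:

    import torch

    hidden_size = 14336          # bigscience/bloom-petals, per the removed comment
    attn_cache_tokens = 8192     # new CLI default: 4 sessions of up to 2048 tokens
    dtype = torch.bfloat16       # assumed block dtype for this example

    max_values_in_cache = 2 * hidden_size * attn_cache_tokens              # keys + values
    cache_bytes_per_block = max_values_in_cache * torch.finfo(dtype).bits // 8
    print(cache_bytes_per_block / 1024**3)                                 # ~0.44 GiB per block

That lands close to the old fixed default of 0.5 GiB per block for BLOOM, so the behavior is roughly preserved while now scaling with the model's hidden size and dtype instead of a hard-coded byte budget.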