Aleksandr Borzunov 3 năm trước cách đây
mục cha
commit
612100ba3e
3 tập tin đã thay đổi với 12 bổ sung12 xóa
  1. 1 1
      .github/workflows/run-tests.yaml
  2. 9 9
      cli/run_server.py
  3. 2 2
      src/server/server.py

+ 1 - 1
.github/workflows/run-tests.yaml

@@ -81,7 +81,7 @@ jobs:
 
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 \
-            --throughput 1 --attention_cache_bytes 0.2GiB &> server1.log &
+            --throughput 1 --attention_cache_size 0.2GiB &> server1.log &
           SERVER1_PID=$!
           
           sleep 5  # wait for the first server to initialize DHT

+ 9 - 9
cli/run_server.py

@@ -56,9 +56,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attention_cache_bytes', type=str, default=None,
+    parser.add_argument('--attn_cache_size', type=str, default=None,
                         help='The size of GPU memory allocated for storing past attention keys/values between inference'
-                             ' steps; examples: 500MB or 1.2GB or 1073741824 ; assumes 1KB = 1kB = 1024 bytes')
+                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); assumes 1KB = 1kB = 1024 bytes')
     parser.add_argument('--revision', type=str, default='main',
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -98,17 +98,17 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
-    attention_cache_bytes = args.pop("attention_cache_bytes")
-    if attention_cache_bytes is not None:
-        attention_cache_bytes = parse_size_as_bytes(attention_cache_bytes)
-    assert isinstance(attention_cache_bytes, (int, type(None))), (
-        "unrecognized value for attention_cache_bytes," " examples: 1.5GB or 1500MB or 1572864000"
-    )
+    attn_cache_size = args.pop("attention_cache_bytes")
+    if attn_cache_size is not None:
+        attention_cache_bytes = parse_size_as_bytes(attn_cache_size)
+    assert isinstance(
+        attn_cache_size, (int, type(None))
+    ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
 
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server.create(**args, start=True, compression=compression, attention_cache_bytes=attention_cache_bytes)
+    server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
 
     try:
         server.join()

+ 2 - 2
src/server/server.py

@@ -110,7 +110,7 @@ class Server(threading.Thread):
         torch_dtype: str = "auto",
         revision: str = "main",
         cache_dir: Optional[str] = None,
-        attention_cache_bytes: Optional[int] = None,
+        b: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
@@ -146,7 +146,7 @@ class Server(threading.Thread):
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, attention_cache_bytes)
+        memory_cache = MemoryCache(device, attn_cache_size)
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]: