Browse Source

fix humane parser

justheuristic 3 năm trước
mục cha
commit
2110b2cdc8
1 tập tin đã thay đổi với 10 bổ sung và 9 xóa
  1. 10 9
      cli/run_server.py

+ 10 - 9
cli/run_server.py

@@ -14,11 +14,11 @@ import re
 def parse_size_as_bytes(size: str) -> int:
     """parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471"""
     units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40, "PB": 2**50}
-    size = size.strip().upper().rstrip("IB ")
+    size = size.strip().upper().replace("IB", "B")
     if not size.endswith("B"):
         size += "B"
     if not re.match(r" ", size):
-        size = re.sub(r"([KMGT]?)", r" \1", size)
+        size = re.sub(r"([KMGT]?B)", r" \1", size)
     number, unit = [string.strip() for string in size.split()]
     return int(float(number) * units[unit])
 
@@ -56,9 +56,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_bytes', type=str, default=None,
+    parser.add_argument('--attention_cache_bytes', type=str, default=None,
                         help='The size of GPU memory allocated for storing past attention keys/values between inference'
-                             ' steps; examples: 500MB or 4.2GB or 1073741824 ; assumes 1KB = 1kB = 1024 bytes')
+                             ' steps; examples: 500MB or 1.2GB or 1073741824 ; assumes 1KB = 1kB = 1024 bytes')
     parser.add_argument('--revision', type=str, default='main',
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -98,15 +98,16 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
-    cache_size_bytes = args.pop("cache_size_bytes")
-    if cache_size_bytes is not None:
-        cache_size_bytes = parse_size_as_bytes(cache_size_bytes)
-    assert isinstance(cache_size_bytes, (int, type(None))), "invalid value for cache_size_bytes, try 1.5GB or 700MB"
+    attention_cache_bytes = args.pop("attention_cache_bytes")
+    if attention_cache_bytes is not None:
+        attention_cache_bytes = parse_size_as_bytes(attention_cache_bytes)
+    assert isinstance(attention_cache_bytes, (int, type(None))), "unrecognized value for attention_cache_bytes," \
+                                                                 " examples: 1.5GB or 1500MB or 1572864000"
 
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server.create(**args, start=True, compression=compression, cache_size_bytes=cache_size_bytes)
+    server = Server.create(**args, start=True, compression=compression, cache_size_bytes=attention_cache_bytes)
 
     try:
         server.join()