
use humanfriendly

Aleksandr Borzunov 3 years ago
parent
commit
0f5c427969
2 files changed, 4 insertions and 16 deletions
  1. cli/run_server.py (+3, −16)
  2. requirements.txt (+1, −0)

+ 3 - 16
cli/run_server.py

@@ -2,26 +2,13 @@ import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from humanfriendly import parse_size
 
 from src.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-import re
-
-
-def parse_size_as_bytes(size: str) -> int:
-    """parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471"""
-    units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40, "PB": 2**50}
-    size = size.strip().upper().replace("IB", "B")
-    if not size.endswith("B"):
-        size += "B"
-    if not re.match(r" ", size):
-        size = re.sub(r"([KMGT]?B)", r" \1", size)
-    number, unit = [string.strip() for string in size.split()]
-    return int(float(number) * units[unit])
-
 
 def main():
     # fmt:off
@@ -58,7 +45,7 @@ def main():
                              "By default, respect the dtypes in the pre-trained state dict.")
     parser.add_argument('--attn_cache_size', type=str, default=None,
                         help='The size of GPU memory allocated for storing past attention keys/values between inference'
-                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); assumes 1KB = 1kB = 1024 bytes')
+                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
     parser.add_argument('--revision', type=str, default='main',
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -100,7 +87,7 @@ def main():
 
     attn_cache_size = args.pop("attn_cache_size")
     if attn_cache_size is not None:
-        attn_cache_size = parse_size_as_bytes(attn_cache_size)
+        attn_cache_size = parse_size(attn_cache_size)
     assert isinstance(
         attn_cache_size, (int, type(None))
     ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"

+ 1 - 0
requirements.txt

@@ -1,5 +1,6 @@
 torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
+humanfriendly
 https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
 https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip