|
@@ -2,26 +2,13 @@ import configargparse
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
from hivemind.utils.limits import increase_file_limit
|
|
from hivemind.utils.limits import increase_file_limit
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
+from humanfriendly import parse_size
|
|
|
|
|
|
from src.server.server import Server
|
|
from src.server.server import Server
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
use_hivemind_log_handler("in_root_logger")
|
|
logger = get_logger(__file__)
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
-import re
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-def parse_size_as_bytes(size: str) -> int:
|
|
|
|
- """parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471"""
|
|
|
|
- units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40, "PB": 2**50}
|
|
|
|
- size = size.strip().upper().replace("IB", "B")
|
|
|
|
- if not size.endswith("B"):
|
|
|
|
- size += "B"
|
|
|
|
- if not re.match(r" ", size):
|
|
|
|
- size = re.sub(r"([KMGT]?B)", r" \1", size)
|
|
|
|
- number, unit = [string.strip() for string in size.split()]
|
|
|
|
- return int(float(number) * units[unit])
|
|
|
|
-
|
|
|
|
|
|
|
|
def main():
|
|
def main():
|
|
# fmt:off
|
|
# fmt:off
|
|
@@ -58,7 +45,7 @@ def main():
|
|
"By default, respect the dtypes in the pre-trained state dict.")
|
|
"By default, respect the dtypes in the pre-trained state dict.")
|
|
parser.add_argument('--attn_cache_size', type=str, default=None,
|
|
parser.add_argument('--attn_cache_size', type=str, default=None,
|
|
help='The size of GPU memory allocated for storing past attention keys/values between inference'
|
|
help='The size of GPU memory allocated for storing past attention keys/values between inference'
|
|
- ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); assumes 1KB = 1kB = 1024 bytes')
|
|
|
|
|
|
+ ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
|
|
parser.add_argument('--revision', type=str, default='main',
|
|
parser.add_argument('--revision', type=str, default='main',
|
|
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
|
|
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
|
|
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
|
|
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
|
|
@@ -100,7 +87,7 @@ def main():
|
|
|
|
|
|
attn_cache_size = args.pop("attn_cache_size")
|
|
attn_cache_size = args.pop("attn_cache_size")
|
|
if attn_cache_size is not None:
|
|
if attn_cache_size is not None:
|
|
- attn_cache_size = parse_size_as_bytes(attn_cache_size)
|
|
|
|
|
|
+ attn_cache_size = parse_size(attn_cache_size)
|
|
assert isinstance(
|
|
assert isinstance(
|
|
attn_cache_size, (int, type(None))
|
|
attn_cache_size, (int, type(None))
|
|
), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
|
|
), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
|