@@ -8,6 +8,23 @@ from src.server.server import Server
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


+import re
+
+
+def parse_size_as_bytes(size: str) -> int:
+ """ parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471 """
+    units = {"B": 1, "KB": 2 ** 10, "MB": 2 ** 20, "GB": 2 ** 30, "TB": 2 ** 40, "PB": 2 ** 50}
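+    # normalize the suffix: drop any trailing "iB"/"B" and spaces, then re-append a single "B" ("1.5GiB" -> "1.5GB")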
+    size = size.strip().upper().rstrip("IB ")
+    if not size.endswith("B"):
+        size += "B"
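+    # split the number from the unit with a space ("1.5GB" -> "1.5 GB")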
+    if not re.match(r' ', size):
+        size = re.sub(r'([KMGTP]?B)', r' \1', size)
+    number, unit = [string.strip() for string in size.split()]
+    return int(float(number) * units[unit])
+
def main():
    # fmt:off
@@ -32,16 +49,19 @@ def main():
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all expert operations')
    parser.add_argument('--max_batch_size', type=int, default=16384,
-                        help='The total number of examples in the same batch will not exceed this value')
+                        help='The total number of tokens in the same batch will not exceed this value')
+    parser.add_argument('--inference_max_length', type=int, default=None,
+                        help='Maximum total sequence length permitted per inference, defaults to max_batch_size tokens')
    parser.add_argument('--cache_dir', type=str, default=None,
                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
-    parser.add_argument('--cache_size_bytes', type=int, default=None,
-                        help='The size of memory cache for storing past attention keys/values between inference steps')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument("--torch_dtype", type=str, default="auto",
                        help="Use this dtype to store block weights and do computations. "
                             "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--attn_cache_bytes', type=str, default=None,
+                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
+                             ' steps; examples: 500MB or 4.2GB or 1073741824 (bytes); assumes 1KB = 1kB = 1024 bytes')
    parser.add_argument('--revision', type=str, default='main',
                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -81,10 +101,16 @@ def main():
    compression_type = args.pop("compression")
    compression = getattr(CompressionType, compression_type)

+    attn_cache_bytes = args.pop("attn_cache_bytes")
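+    # convert the human-readable size (e.g. "1.5GB") into an integer number of bytes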
+    if attn_cache_bytes is not None:
+        attn_cache_bytes = parse_size_as_bytes(attn_cache_bytes)
+    assert isinstance(attn_cache_bytes, (int, type(None))), "invalid value for --attn_cache_bytes, try 1.5GB or 700MB"
+
    use_auth_token = args.pop("use_auth_token")
    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token

-    server = Server.create(**args, start=True, compression=compression)
+    server = Server.create(**args, start=True, compression=compression, cache_size_bytes=attn_cache_bytes)
    try:
        server.join()