|
@@ -56,9 +56,9 @@ def main():
|
|
|
parser.add_argument("--torch_dtype", type=str, default="auto",
|
|
|
help="Use this dtype to store block weights and do computations. "
|
|
|
"By default, respect the dtypes in the pre-trained state dict.")
|
|
|
- parser.add_argument('--attention_cache_bytes', type=str, default=None,
|
|
|
+ parser.add_argument('--attn_cache_size', type=str, default=None,
|
|
|
help='The size of GPU memory allocated for storing past attention keys/values between inference'
|
|
|
- ' steps; examples: 500MB or 1.2GB or 1073741824 ; assumes 1KB = 1kB = 1024 bytes')
|
|
|
+ ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); assumes 1KB = 1kB = 1024 bytes')
|
|
|
parser.add_argument('--revision', type=str, default='main',
|
|
|
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
|
|
|
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
|
|
@@ -98,17 +98,17 @@ def main():
|
|
|
compression_type = args.pop("compression")
|
|
|
compression = getattr(CompressionType, compression_type)
|
|
|
|
|
|
- attention_cache_bytes = args.pop("attention_cache_bytes")
|
|
|
- if attention_cache_bytes is not None:
|
|
|
- attention_cache_bytes = parse_size_as_bytes(attention_cache_bytes)
|
|
|
- assert isinstance(attention_cache_bytes, (int, type(None))), (
|
|
|
- "unrecognized value for attention_cache_bytes," " examples: 1.5GB or 1500MB or 1572864000"
|
|
|
- )
|
|
|
+ attn_cache_size = args.pop("attn_cache_size")
|
|
|
+ if attn_cache_size is not None:
|
|
|
+ attn_cache_size = parse_size_as_bytes(attn_cache_size)
|
|
|
+ assert isinstance(
|
|
|
+ attn_cache_size, (int, type(None))
|
|
|
+ ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
|
|
|
|
|
|
use_auth_token = args.pop("use_auth_token")
|
|
|
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
|
|
|
|
|
|
- server = Server.create(**args, start=True, compression=compression, attention_cache_bytes=attention_cache_bytes)
|
|
|
+ server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
|
|
|
|
|
|
try:
|
|
|
server.join()
|