|
@@ -105,6 +105,7 @@ class Server(threading.Thread):
|
|
|
min_batch_size: int = 1,
|
|
|
max_batch_size: int = 4096,
|
|
|
torch_dtype: str = "auto",
|
|
|
+ revision: str = "main",
|
|
|
cache_dir: Optional[str] = None,
|
|
|
cache_size_bytes: Optional[int] = None,
|
|
|
device: Optional[Union[str, torch.device]] = None,
|
|
@@ -151,7 +152,7 @@ class Server(threading.Thread):
|
|
|
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
|
|
|
|
block_config = BloomConfig.from_pretrained(
|
|
|
- converted_model_name_or_path, use_auth_token=use_auth_token, revision="client"
|
|
|
+ converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
|
|
|
)
|
|
|
|
|
|
if block_indices is not None:
|