@@ -2,7 +2,7 @@ from __future__ import annotations

 import multiprocessing as mp
 import threading
-from typing import Dict, Optional, Sequence, Union
+from typing import Dict, Literal, Optional, Sequence, Union

 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -19,6 +19,7 @@ from src.server.backend import TransformerBackend
 from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
+from src.server.throughput import get_host_throughput

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -95,7 +96,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
-        throughput: float,
+        throughput: Union[float, Literal["auto", "eval"]],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -103,7 +104,7 @@ class Server(threading.Thread):
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
         cache_size_bytes: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -136,6 +137,10 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)

+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
+
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
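A quick usage sketch for reviewers (not part of the diff): after this change, `throughput` still accepts an explicit float, or one of the two new string modes. The `Server.create` signature and the `throughput`/`device` arguments come from the hunks above; the model path in the snippet is a placeholder, not a real argument value from this PR.

```python
# Hypothetical call based on the signature shown in the hunks above.
server = Server.create(
    prefix=None,
    converted_model_name_or_path="bigscience/test-bloomd",  # placeholder model path
    throughput="auto",  # or "eval", or an explicit float as before
    device=None,        # resolved to "cuda" if available, else "cpu"
)
```

Judging by the `force_eval=(throughput == "eval")` argument, `"auto"` is presumably allowed to reuse a previously measured throughput for this host, while `"eval"` forces a fresh measurement; either way, `throughput` is normalized to a plain float before the rest of `create` runs.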