@@ -110,7 +110,7 @@ class Server(threading.Thread):
         torch_dtype: str = "auto",
         revision: str = "main",
         cache_dir: Optional[str] = None,
-        cache_size_bytes: Optional[int] = None,
+        attention_cache_bytes: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
@@ -146,7 +146,7 @@ class Server(threading.Thread):
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, cache_size_bytes)
+        memory_cache = MemoryCache(device, attention_cache_bytes)

         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
@@ -233,6 +233,7 @@ class Server(threading.Thread):
             blocks,
             throughput=throughput,
             num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
             device=device,
             stats_report_interval=stats_report_interval,
             update_period=update_period,