|
@@ -29,6 +29,7 @@ from petals.server.handler import TransformerConnectionHandler
|
|
from petals.server.memory_cache import MemoryCache
|
|
from petals.server.memory_cache import MemoryCache
|
|
from petals.server.throughput import get_host_throughput
|
|
from petals.server.throughput import get_host_throughput
|
|
from petals.utils.convert_8bit import replace_8bit_linear
|
|
from petals.utils.convert_8bit import replace_8bit_linear
|
|
|
|
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
use_hivemind_log_handler("in_root_logger")
|
|
logger = get_logger(__file__)
|
|
logger = get_logger(__file__)
|
|
@@ -56,6 +57,7 @@ class Server:
|
|
torch_dtype: str = "auto",
|
|
torch_dtype: str = "auto",
|
|
revision: str = "main",
|
|
revision: str = "main",
|
|
cache_dir: Optional[str] = None,
|
|
cache_dir: Optional[str] = None,
|
|
|
|
+ max_disk_space: Optional[int] = None,
|
|
attn_cache_size: Optional[int] = None,
|
|
attn_cache_size: Optional[int] = None,
|
|
alloc_timeout: float = 60,
|
|
alloc_timeout: float = 60,
|
|
device: Optional[Union[str, torch.device]] = None,
|
|
device: Optional[Union[str, torch.device]] = None,
|
|
@@ -82,7 +84,6 @@ class Server:
|
|
self.num_handlers = num_handlers
|
|
self.num_handlers = num_handlers
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
self.inference_max_length = inference_max_length
|
|
self.inference_max_length = inference_max_length
|
|
- self.cache_dir = cache_dir
|
|
|
|
self.compression = compression
|
|
self.compression = compression
|
|
self.stats_report_interval, self.update_period = stats_report_interval, update_period
|
|
self.stats_report_interval, self.update_period = stats_report_interval, update_period
|
|
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
|
|
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
|
|
@@ -117,7 +118,8 @@ class Server:
|
|
self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
|
|
self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
|
|
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
|
|
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
|
|
if initial_peers == PUBLIC_INITIAL_PEERS:
|
|
if initial_peers == PUBLIC_INITIAL_PEERS:
|
|
- logger.info("Connecting to the public Petals swarm")
|
|
|
|
|
|
+ logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
|
|
|
|
+ logger.info("Please check that your server is reachable at http://health.petals.ml")
|
|
else:
|
|
else:
|
|
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
|
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
|
|
|
|
|
@@ -158,6 +160,11 @@ class Server:
|
|
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
|
|
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
|
|
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
|
|
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
|
|
|
|
|
|
|
|
+ if cache_dir is None:
|
|
|
|
+ cache_dir = DEFAULT_CACHE_DIR
|
|
|
|
+ self.cache_dir = cache_dir
|
|
|
|
+ self.max_disk_space = max_disk_space
|
|
|
|
+
|
|
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
|
|
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
|
|
if throughput in ["auto", "eval"]:
|
|
if throughput in ["auto", "eval"]:
|
|
throughput = get_host_throughput(
|
|
throughput = get_host_throughput(
|
|
@@ -213,6 +220,7 @@ class Server:
|
|
inference_max_length=self.inference_max_length,
|
|
inference_max_length=self.inference_max_length,
|
|
torch_dtype=self.torch_dtype,
|
|
torch_dtype=self.torch_dtype,
|
|
cache_dir=self.cache_dir,
|
|
cache_dir=self.cache_dir,
|
|
|
|
+ max_disk_space=self.max_disk_space,
|
|
device=self.device,
|
|
device=self.device,
|
|
compression=self.compression,
|
|
compression=self.compression,
|
|
stats_report_interval=self.stats_report_interval,
|
|
stats_report_interval=self.stats_report_interval,
|
|
@@ -308,7 +316,8 @@ class ModuleContainer(threading.Thread):
|
|
min_batch_size: int,
|
|
min_batch_size: int,
|
|
max_batch_size: int,
|
|
max_batch_size: int,
|
|
torch_dtype: torch.dtype,
|
|
torch_dtype: torch.dtype,
|
|
- cache_dir: Optional[str],
|
|
|
|
|
|
+ cache_dir: str,
|
|
|
|
+ max_disk_space: int,
|
|
device: Union[str, torch.device],
|
|
device: Union[str, torch.device],
|
|
compression: CompressionType,
|
|
compression: CompressionType,
|
|
update_period: float,
|
|
update_period: float,
|
|
@@ -340,6 +349,7 @@ class ModuleContainer(threading.Thread):
|
|
torch_dtype=torch_dtype,
|
|
torch_dtype=torch_dtype,
|
|
use_auth_token=use_auth_token,
|
|
use_auth_token=use_auth_token,
|
|
cache_dir=cache_dir,
|
|
cache_dir=cache_dir,
|
|
|
|
+ max_disk_space=max_disk_space,
|
|
)
|
|
)
|
|
|
|
|
|
if load_in_8bit:
|
|
if load_in_8bit:
|