|
@@ -83,7 +83,7 @@ class Server:
|
|
|
quant_type: Optional[QuantType] = None,
|
|
|
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
|
|
|
skip_reachability_check: bool = False,
|
|
|
- dht_client_mode: Optional[bool] = None,
|
|
|
+ reachable_via_relay: Optional[bool] = None,
|
|
|
use_relay: bool = True,
|
|
|
use_auto_relay: bool = True,
|
|
|
adapters: Sequence[str] = (),
|
|
@@ -129,20 +129,20 @@ class Server:
|
|
|
for block_index in range(self.block_config.num_hidden_layers)
|
|
|
]
|
|
|
|
|
|
- if dht_client_mode is None:
|
|
|
+ if reachable_via_relay is None:
|
|
|
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
|
|
|
- dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer
|
|
|
- logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
|
|
|
+ reachable_via_relay = is_reachable is False # if can't check reachability (returns None), run a full peer
|
|
|
+ logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
|
|
|
self.dht = DHT(
|
|
|
initial_peers=initial_peers,
|
|
|
start=True,
|
|
|
num_workers=self.block_config.num_hidden_layers,
|
|
|
use_relay=use_relay,
|
|
|
use_auto_relay=use_auto_relay,
|
|
|
- client_mode=dht_client_mode,
|
|
|
+ client_mode=reachable_via_relay,
|
|
|
**kwargs,
|
|
|
)
|
|
|
- self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
|
|
|
+ self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
|
|
|
|
|
|
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
|
|
|
if initial_peers == PUBLIC_INITIAL_PEERS:
|
|
@@ -227,6 +227,7 @@ class Server:
|
|
|
num_blocks=num_blocks,
|
|
|
quant_type=quant_type,
|
|
|
tensor_parallel_devices=self.tensor_parallel_devices,
|
|
|
+ reachable_via_relay=reachable_via_relay,
|
|
|
force_eval=(throughput == "eval"),
|
|
|
cache_dir=cache_dir,
|
|
|
)
|
|
@@ -239,7 +240,7 @@ class Server:
|
|
|
adapters=tuple(adapters),
|
|
|
torch_dtype=str(torch_dtype).replace("torch.", ""),
|
|
|
quant_type=quant_type.name.lower(),
|
|
|
- using_relay=self.dht.client_mode,
|
|
|
+ using_relay=reachable_via_relay,
|
|
|
**throughput_info,
|
|
|
)
|
|
|
|