|
@@ -5,7 +5,7 @@ import multiprocessing as mp
|
|
|
import random
|
|
|
import threading
|
|
|
import time
|
|
|
-from typing import Dict, List, Optional, Sequence, Union
|
|
|
+from typing import Dict, List, Optional, Union
|
|
|
|
|
|
import numpy as np
|
|
|
import psutil
|
|
@@ -18,6 +18,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
|
from src import BloomConfig, declare_active_modules
|
|
|
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
|
|
|
+from src.constants import PUBLIC_INITIAL_PEERS
|
|
|
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
|
|
|
from src.dht_utils import get_remote_module_infos
|
|
|
from src.server import block_selection
|
|
@@ -104,7 +105,10 @@ class Server(threading.Thread):
|
|
|
|
|
|
self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
|
|
|
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
|
|
|
- logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
|
|
+ if initial_peers == PUBLIC_INITIAL_PEERS:
|
|
|
+ logger.info("Connecting to the public Petals swarm")
|
|
|
+ else:
|
|
|
+ logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
|
|
|
|
|
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
self.device = device
|