@@ -13,8 +13,10 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules, BloomConfig
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from src.dht_utils import get_remote_module_infos
 from src.server.backend import TransformerBackend
+from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
@@ -32,19 +34,26 @@ class Server(threading.Thread):
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
+        throughput: float,
         update_period: float = 30,
         expiration: Optional[float] = None,
         start: bool,
         **kwargs,
     ):
         threading.Thread.__init__(self)
-        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends, dht, update_period, expiration, daemon=True
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
         )
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
@@ -86,6 +95,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
+        throughput: float,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -116,6 +126,9 @@ class Server(threading.Thread):
             )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
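+        # By default, DHT entries live for two update periods, but no less than the maximum allowed clock discrepancy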
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@@ -127,6 +140,10 @@ class Server(threading.Thread):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
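+        # Load the model config now: automatic block selection below needs to know the total number of layers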
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token
+        )
+
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
@@ -137,16 +154,22 @@ class Server(threading.Thread):
             block_indices = range(first_block_index, last_block_index)
         else:
             assert num_blocks is not None
-            block_indices = range(num_blocks)  # TODO replace with proper load balancing
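+            # Select blocks automatically: fetch the current state of every block and let choose_best_blocks decide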
+            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
+            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
+            block_indices = choose_best_blocks(num_blocks, module_infos)
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
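+        # Mark the chosen blocks as JOINING while the weights are loading, so that other peers can already see them as claimed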
+        declare_active_modules(
+            dht,
+            module_uids,
+            expiration_time=get_dht_time() + expiration,
+            state=ServerState.JOINING,
+            throughput=throughput,
         )
+        logger.info(f"Announced that blocks {block_indices} are joining")
 
-        # initialize modules
         blocks = {}
-        for block_index in block_indices:
-            module_uid = f"{prefix}.{block_index}"
+        for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
                 converted_model_name_or_path,
                 block_index,
@@ -173,6 +196,7 @@ class Server(threading.Thread):
         return cls(
             dht,
             blocks,
+            throughput=throughput,
             num_connection_handlers=num_handlers,
             device=device,
             stats_report_interval=stats_report_interval,
@@ -209,6 +233,16 @@ class Server(threading.Thread):
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
+        if self.module_backends:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
         self.ready.clear()
 
         for process in self.conn_handlers:
@@ -230,25 +264,38 @@ class Server(threading.Thread):
         logger.debug(f"Shutting down runtime")
 
         self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
+        logger.info("Server shut down successfully")
 
 
 class ModuleAnnouncerThread(threading.Thread):
     """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
 
     def __init__(
-        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+        self,
+        module_backends: Dict[str, TransformerBackend],
+        dht: DHT,
+        *,
+        throughput: float,
+        update_period: float = 30,
+        expiration: float,
+        **kwargs
     ):
         super().__init__(**kwargs)
-        if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.module_backends = module_backends
         self.dht = dht
+        self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
-        while not self.stop.wait(self.update_period):
-            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
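+        # Announce the blocks as ONLINE right away, then re-announce every update_period until stop is set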
+        while True:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.ONLINE,
+                throughput=self.throughput,
+            )
+            if self.stop.wait(self.update_period):
+                break