@@ -7,12 +7,13 @@ import threading
 from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Optional, Tuple, List
+from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.checkpoint_saver import CheckpointSaver
+from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
 from hivemind.server.expert_backend import ExpertBackend
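The hunk above renames `hivemind.server.checkpoint_saver` to `hivemind.server.checkpoints` and pulls in two helpers used further down. Their bodies are not part of this diff; the sketch below shows plausible minimal implementations inferred from how the patch calls them (the `expert` attribute on `ExpertBackend` is an assumption, not confirmed by this diff):

```python
from pathlib import Path
from typing import Dict

import torch


def dir_is_correct(directory: Path) -> bool:
    # The patch uses this as `assert dir_is_correct(checkpoint_dir)`,
    # so it must validate the path and return a truthy value.
    assert directory is not None
    assert directory.exists() and directory.is_dir()
    return True


def load_weights(experts: Dict[str, 'hivemind.ExpertBackend'], directory: Path) -> None:
    # Restore each expert from <directory>/<uid>/checkpoint_last.pt -- the same
    # layout the UID-discovery scan in the patch looks for.
    assert dir_is_correct(directory)
    for expert_uid, expert_backend in experts.items():
        checkpoint_path = directory / expert_uid / 'checkpoint_last.pt'
        # Assumption: the backend exposes its wrapped torch module as `.expert`.
        expert_backend.expert.load_state_dict(torch.load(checkpoint_path))
```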
@@ -69,8 +70,8 @@ class Server(threading.Thread):
     @staticmethod
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None,
-               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -91,6 +92,9 @@ class Server(threading.Thread):
         :param dht_port: DHT node will listen on this port, default = find open port
             You can then use this node as initial peer for subsequent servers.
 
+        :param checkpoint_dir: directory to save expert checkpoints
+        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
             for each BatchTensorProto in ExpertBackend for the respective experts.
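Taken together with the new keyword arguments, this makes checkpoint reuse a one-flag affair. A hypothetical invocation (the path and expert settings are illustrative, and it assumes saving to `checkpoint_dir` is wired up elsewhere in `create`, which this diff does not show):

```python
from pathlib import Path

import hivemind

# First launch: generate experts from a pattern, checkpointing under ./my_experts.
server = hivemind.Server.create(num_experts=2, expert_pattern='expert.[0:256]',
                                checkpoint_dir=Path('my_experts'), start=True)

# Later relaunch: rediscover and reload the same experts from disk.
server = hivemind.Server.create(checkpoint_dir=Path('my_experts'),
                                load_experts=True, start=True)
```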
@@ -100,8 +104,6 @@ class Server(threading.Thread):
         if len(kwargs) != 0:
             logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
 
         if no_dht:
             dht = None
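Note that the mutual-exclusion assert deleted here is not dropped: it reappears in the next hunk, *after* the checkpoint-discovery block, because `expert_uids` may now be populated from `checkpoint_dir` before validation runs.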
@@ -110,7 +112,20 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        # get expert uids
+        if load_experts:
+            assert dir_is_correct(checkpoint_dir)
+            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
+            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
+            if expert_uids:
+                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
+            else:
+                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
+                expert_uids = None  # an empty list would skip UID generation below and start zero experts
+
+        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
+            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+
+        # get expert uids if not loaded previously
         if expert_uids is None:
             assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
             logger.info(f"Generating expert uids from pattern {expert_pattern}")
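The discovery step relies on a simple on-disk convention: one subdirectory per expert UID, each containing a `checkpoint_last.pt`. A self-contained illustration of that layout and of the same scan the patch performs (paths are hypothetical):

```python
from pathlib import Path

import torch

checkpoint_dir = Path('my_experts')  # hypothetical checkpoint root

# Build the expected layout:
#   my_experts/expert.0/checkpoint_last.pt
#   my_experts/expert.1/checkpoint_last.pt
for uid in ('expert.0', 'expert.1'):
    (checkpoint_dir / uid).mkdir(parents=True, exist_ok=True)
    torch.save({}, checkpoint_dir / uid / 'checkpoint_last.pt')

# The scan from the patch: directory names double as expert UIDs.
expert_uids = [child.name for child in checkpoint_dir.iterdir()
               if (child / 'checkpoint_last.pt').exists()]
print(sorted(expert_uids))  # ['expert.0', 'expert.1']
```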
@@ -138,6 +153,9 @@ class Server(threading.Thread):
                                                         opt=optim_cls(expert.parameters()),
                                                         max_batch_size=max_batch_size)
 
+        if load_experts:
+            load_weights(experts, checkpoint_dir)
+
         server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                         start=start)
         return server
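`load_weights(experts, checkpoint_dir)` runs only after every `ExpertBackend`, including its optimizer, has been constructed. Under the `load_weights` sketch earlier, that means module weights are restored while optimizers start fresh; persisting optimizer state would require a richer checkpoint format than a single `checkpoint_last.pt` per expert.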