@@ -2,23 +2,22 @@ from __future__ import annotations
 
 import multiprocessing as mp
 import multiprocessing.synchronize
-import random
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple
 from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
+from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -68,11 +67,12 @@ class Server(threading.Thread):
         if start:
             self.run_in_background(await_ready=True)
 
-    @staticmethod
-    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+    @classmethod
+    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
+               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
+               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,7 +85,12 @@ class Server(threading.Thread):
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
         :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
+
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
@@ -93,8 +98,7 @@ class Server(threading.Thread):
         :param dht_port: DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.
 
-        :param checkpoint_dir: directory to save expert checkpoints
-        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+        :param checkpoint_dir: directory to save and load expert checkpoints
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
@@ -113,23 +117,29 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        if load_experts:
-            assert dir_is_correct(checkpoint_dir)
-            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
-            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
-            if expert_uids:
-                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
-            else:
-                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
-
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
+                (num_experts is not None and expert_uids is None)), \
+            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
 
-        # get expert uids if not loaded previously
         if expert_uids is None:
-            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
-            logger.info(f"Generating expert uids from pattern {expert_pattern}")
-            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
+                               (child / 'checkpoint_last.pt').exists()]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
@@ -142,6 +152,8 @@ class Server(threading.Thread):
         else:
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
+        scheduler = schedule_name_to_scheduler[scheduler]
+
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
@@ -150,15 +162,17 @@ class Server(threading.Thread):
                                                          args_schema=args_schema,
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
-                                                         opt=optim_cls(expert.parameters()),
+                                                         optimizer=optim_cls(expert.parameters()),
+                                                         scheduler=scheduler,
+                                                         num_warmup_steps=num_warmup_steps,
+                                                         num_training_steps=num_training_steps,
                                                          max_batch_size=max_batch_size)
 
-        if load_experts:
-            load_weights(experts, checkpoint_dir)
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
 
-        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                        start=start)
-        return server
+        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                   checkpoint_dir=checkpoint_dir, start=start)
 
     def run(self):
         """
@@ -241,7 +255,7 @@ class Server(threading.Thread):
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
     try:
         runner.start()
@@ -269,63 +283,3 @@ def _server_runner(pipe, *args, **kwargs):
         server.shutdown()
         server.join()
         logger.info("Server shut down.")
-
-
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if '[' not in block and ']' not in block:
-                    uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
-                                    if found_expert is not None}
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
-    return found_uids
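
Note on the new scheduler plumbing: create() now resolves the string name through schedule_name_to_scheduler and passes the resulting factory, together with num_warmup_steps and num_training_steps, to every ExpertBackend. The diff does not show what that registry contains, so the snippet below is only a plausible sketch of such a mapping; the factory name and the 'linear' key are assumptions, not hivemind's actual definitions.

# Illustrative sketch only: one plausible shape for a name -> LR-scheduler-factory
# registry like hivemind.server.layers.schedule_name_to_scheduler (names assumed).
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_linear_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int,
                                    num_training_steps: int) -> LambdaLR:
    """Warm the LR up linearly for num_warmup_steps, then decay it linearly to zero."""
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        remaining = num_training_steps - current_step
        return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))
    return LambdaLR(optimizer, lr_lambda)


schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}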
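
And a minimal usage sketch of the updated Server.create signature, assuming 'linear' is a registered scheduler name and that checkpoints should go to a local directory; the chosen values are illustrative, not taken from the diff.

# Hypothetical call to the new create() API; the scheduler name, path and step
# counts below are assumptions for illustration only.
from pathlib import Path
import hivemind

server = hivemind.Server.create(
    listen_on='0.0.0.0:*',               # a free port is picked automatically
    num_experts=4,
    expert_pattern='expert.[0:256]',     # uids sampled as expert.<0..255>
    expert_cls='ffn',
    hidden_dim=1024,
    scheduler='linear',                  # resolved via schedule_name_to_scheduler
    num_warmup_steps=1_000,
    num_training_steps=100_000,
    checkpoint_dir=Path('checkpoints'),  # also used to restore previously saved experts
    start=True,
)
try:
    server.join()                        # Server subclasses threading.Thread
finally:
    server.shutdown()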