@@ -69,7 +69,7 @@ class Server(threading.Thread):
@staticmethod
def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
- device=None, no_dht=False, initial_peers=(), dht_port=None, verbose=True,
+ device=None, no_dht=False, initial_peers=(), dht_port=None,
compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
"""
Instantiate a server with several identical experts. See argparse comments below for details
@@ -91,31 +91,29 @@ class Server(threading.Thread):
:param dht_port: DHT node will listen on this port, default = find open port
You can then use this node as initial peer for subsequent servers.
 
- :param verbose: whether to print server started / finished / terminated events
:param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
hosted on this server. For a more fine-grained compression, start server in python and specify compression
for each BatchTensorProto in ExpertBackend for the respective experts.
 
:param start: if True, starts server right away and returns when server is ready for requests
"""
- if verbose and len(kwargs) != 0:
- print("Ignored kwargs:", kwargs)
+ if len(kwargs) != 0:
+ logger.info(f"Ignored kwargs: {kwargs}")
assert expert_cls in name_to_block
+ assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
+ "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
 
- # initialize dht
- dht = None
- if not no_dht:
- logger.info(f"Bootstrapping DHT node, initial peers = {initial_peers}")
+ if no_dht:
+ dht = None
+ else:
dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
- if verbose:
- logger.info(f"Running dht node on port {dht.port}")
+ logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
# get expert uids
- assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
- "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
if expert_uids is None:
assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
+ logger.info(f"Generating expert uids from pattern {expert_pattern}")
expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
 
num_experts = len(expert_uids)
@@ -130,7 +128,6 @@ class Server(threading.Thread):
args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
# initialize experts
-
experts = {}
for expert_uid in expert_uids:
expert = name_to_block[expert_cls](hidden_dim)
@@ -139,18 +136,10 @@ class Server(threading.Thread):
outputs_schema=hivemind.BatchTensorDescriptor(
hidden_dim, compression=compression),
opt=optim_cls(expert.parameters()),
- max_batch_size=max_batch_size,
- )
- # actually start server
- server = Server(
- dht, experts, listen_on=listen_on,
- num_connection_handlers=num_handlers, device=device)
+ max_batch_size=max_batch_size)
 
- if start:
- server.run_in_background(await_ready=True)
- if verbose:
- logger.info(f"Server started at {server.listen_on}")
- logger.info(f"Got {len(experts)} active experts of type {expert_cls}: {list(experts.keys())}")
+ server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+ start=start)
return server
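# Usage sketch (illustrative only; the values below are assumptions, not part of this change):
# with `verbose` gone, output is controlled solely by the hivemind logger configuration.
#     server = Server.create(num_experts=16, expert_pattern="expert.[0:256]",
#                            expert_cls='ffn', hidden_dim=1024, start=True)
#     server.join()  # blocks until the server thread terminates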
 
def run(self):
@@ -158,6 +147,12 @@ class Server(threading.Thread):
Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
+ logger.info(f"Server started at {self.listen_on}")
+ logger.info(f"Got {len(self.experts)} experts:")
+ for expert_name, backend in self.experts.items():
+ num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+ logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
if self.dht:
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
@@ -172,8 +167,6 @@ class Server(threading.Thread):
for process in self.conn_handlers:
if not process.is_alive():
process.start()
-
- for process in self.conn_handlers:
process.ready.wait()
 
self.runtime.run()
@@ -227,11 +220,10 @@ class Server(threading.Thread):
 
 
@contextmanager
-def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
""" A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
pipe, runners_pipe = mp.Pipe(duplex=True)
- runner = mp.get_context("spawn").Process(
- target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))
+ runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
try:
runner.start()
@@ -240,15 +232,13 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tupl
finally:
runner.join(timeout=shutdown_timeout)
if runner.is_alive():
- if verbose:
- logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+ logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
runner.kill()
- if verbose:
- logger.info("Server terminated.")
+ logger.info("Server terminated.")
 
 
-def _server_runner(pipe, *args, verbose, **kwargs):
- server = Server.create(*args, verbose=verbose, start=True, **kwargs)
+def _server_runner(pipe, *args, **kwargs):
+ server = Server.create(*args, start=True, **kwargs)
try:
if server.dht is not None:
dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
@@ -257,12 +247,10 @@ def _server_runner(pipe, *args, verbose, **kwargs):
pipe.send((server.listen_on, dht_listen_on))
pipe.recv() # wait for shutdown signal
finally:
- if verbose:
- logger.info("Shutting down server...")
+ logger.info("Shutting down server...")
server.shutdown()
server.join()
- if verbose:
- logger.info("Server shut down successfully.")
+ logger.info("Server shut down.")
 
 
def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
@@ -277,7 +265,6 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
:note: this method is not strictly process-safe. If several servers run it concurrently, they have
a small chance of sampling duplicate expert uids.
"""
- logger.info("Generating expert uids...")
remaining_attempts = attempts_per_expert * num_experts
found_uids, attempted_uids = list(), set()
 
@@ -298,7 +285,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
except KeyboardInterrupt as e:
raise e
except Exception as e:
- raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block} , {e}")
+ raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
return hivemind.dht.UID_DELIMITER.join(uid)
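# Example (an assumption based on the block parsing above): a pattern such as
# "expert.[0:256].[0:256]" would yield uids like "expert.42.137", with each [a:b]
# block replaced by an independent random draw from range(a, b).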
 
while remaining_attempts > 0 and len(found_uids) < num_experts: