|
@@ -1,43 +1,75 @@
|
|
|
+from __future__ import annotations
|
|
|
import threading
|
|
|
from typing import Optional, Dict, Union, Sequence
|
|
|
|
|
|
import torch
|
|
|
-from hivemind import Server, DHT
|
|
|
+from hivemind import DHT, BatchTensorDescriptor
|
|
|
from hivemind.moe.server.dht_handler import DHTHandlerThread
|
|
|
from hivemind.moe.server.layers import add_custom_models_from_file
|
|
|
from hivemind.moe.server.runtime import Runtime
|
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
|
|
|
+import multiprocessing as mp
|
|
|
|
|
|
from src import DistributedBloomConfig
|
|
|
from src.bloom.block import BloomBlock
|
|
|
from src.server.cache import MemoryCache
|
|
|
-from src.server.backend import BloomBlockBackend
|
|
|
+from src.server.backend import TransformerBlockBackend
|
|
|
from src.server.handler import BloomConnectionHandler
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-class BloomServer(Server):
|
|
|
+class Server(threading.Thread):
|
|
|
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
|
|
|
def __init__(
|
|
|
- self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
|
|
|
- device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
|
|
|
- cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
|
|
|
+ self, dht: DHT, module_backends: Dict[str, TransformerBlockBackend], *,
|
|
|
+ device: torch.device, num_connection_handlers: int = 8,
|
|
|
+ update_period: float = 30, expiration: Optional[float] = None,
|
|
|
+ start: bool, **kwargs
|
|
|
):
|
|
|
threading.Thread.__init__(self)
|
|
|
- self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
|
|
|
-
|
|
|
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
|
|
|
- self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
|
|
|
+ self.conn_handlers = [
|
|
|
+ BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
|
|
|
+ ]
|
|
|
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
- self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
|
|
|
+ self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
|
|
|
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
|
|
|
|
|
if start:
|
|
|
self.run_in_background(await_ready=True)
|
|
|
|
|
|
+ def run(self):
|
|
|
+ """
|
|
|
+ Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
|
|
|
+ runs Runtime (self.runtime) to process incoming requests.
|
|
|
+ """
|
|
|
+ logger.info(f"Serving {len(self.module_backends)} blocks:")
|
|
|
+ for expert_name, backend in self.module_backends.items():
|
|
|
+ num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
|
|
|
+ logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
|
|
|
+
|
|
|
+ if not self.dht.is_alive():
|
|
|
+ self.dht.run_in_background(await_ready=True)
|
|
|
+
|
|
|
+ if self.module_backends:
|
|
|
+ self.dht_handler_thread.start()
|
|
|
+
|
|
|
+ if self.checkpoint_saver is not None:
|
|
|
+ self.checkpoint_saver.start()
|
|
|
+
|
|
|
+ for process in self.conn_handlers:
|
|
|
+ if not process.is_alive():
|
|
|
+ process.start()
|
|
|
+ process.ready.result()
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.runtime.run()
|
|
|
+ finally:
|
|
|
+ self.shutdown()
|
|
|
+
|
|
|
# noinspection PyMethodOverriding
|
|
|
@classmethod
|
|
|
def create(
|
|
@@ -69,40 +101,84 @@ class BloomServer(Server):
|
|
|
|
|
|
num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
|
|
|
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
- if isinstance(block_config, str):
|
|
|
- block_config = DistributedBloomConfig
|
|
|
+ block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
|
|
|
|
|
|
# initialize modules
|
|
|
- module_backends = {}
|
|
|
- for i in range(len(module_backends)):
|
|
|
+ blocks = {}
|
|
|
+ for i in range(num_blocks):
|
|
|
module_uid = f"dummy_block.{i}"
|
|
|
- block = BloomBlock(block_config, layer_number=i)
|
|
|
- #TODO run the actual model
|
|
|
-
|
|
|
- module_backends[module_uid] = BloomBlockBackend(
|
|
|
- name=expert_uid,
|
|
|
- expert=block,
|
|
|
- args_schema=args_schema,
|
|
|
- num_warmup_steps=num_warmup_steps,
|
|
|
- num_total_steps=num_total_steps,
|
|
|
- clip_grad_norm=clip_grad_norm,
|
|
|
+ HARDCODCED_LENGTH = 2048
|
|
|
+
|
|
|
+ blocks[module_uid] = TransformerBlockBackend(
|
|
|
+ module_uid,
|
|
|
+ BloomBlock(block_config, layer_number=i),
|
|
|
+ args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
|
|
|
+ kwargs_schema={},
|
|
|
+ outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
|
|
|
min_batch_size=min_batch_size,
|
|
|
max_batch_size=max_batch_size,
|
|
|
)
|
|
|
|
|
|
- if checkpoint_dir is not None:
|
|
|
- load_experts(experts, checkpoint_dir)
|
|
|
-
|
|
|
return cls(
|
|
|
dht,
|
|
|
- experts,
|
|
|
+ blocks,
|
|
|
cache_size_bytes=cache_size_bytes,
|
|
|
num_connection_handlers=num_handlers,
|
|
|
device=device,
|
|
|
- checkpoint_dir=checkpoint_dir,
|
|
|
stats_report_interval=stats_report_interval,
|
|
|
update_period=update_period,
|
|
|
expiration=expiration,
|
|
|
start=start,
|
|
|
)
|
|
|
|
|
|
+ def run_in_background(self, await_ready=True, timeout=None):
|
|
|
+ """
|
|
|
+ Starts Server in a background thread. if await_ready, this method will wait until background server
|
|
|
+ is ready to process incoming requests or for :timeout: seconds max.
|
|
|
+ """
|
|
|
+ self.start()
|
|
|
+ if await_ready and not self.ready.wait(timeout=timeout):
|
|
|
+ raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
|
|
|
+
|
|
|
+ @property
|
|
|
+ def ready(self) -> mp.synchronize.Event:
|
|
|
+ """
|
|
|
+ An event (multiprocessing.Event) that is set when the server is ready to process requests.
|
|
|
+
|
|
|
+ Example
|
|
|
+ =======
|
|
|
+ >>> server.start()
|
|
|
+ >>> server.ready.wait(timeout=10)
|
|
|
+ >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
|
|
|
+ """
|
|
|
+ return self.runtime.ready # mp.Event that is true if self is ready to process batches
|
|
|
+
|
|
|
+ def shutdown(self):
|
|
|
+ """
|
|
|
+ Gracefully terminate the server, process-safe.
|
|
|
+ Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
|
|
|
+ If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
|
|
|
+ """
|
|
|
+ self.ready.clear()
|
|
|
+
|
|
|
+ for process in self.conn_handlers:
|
|
|
+ process.terminate()
|
|
|
+ process.join()
|
|
|
+ logger.debug("Connection handlers terminated")
|
|
|
+
|
|
|
+ if self.module_backends:
|
|
|
+ self.dht_handler_thread.stop.set()
|
|
|
+ self.dht_handler_thread.join()
|
|
|
+
|
|
|
+ if self.checkpoint_saver is not None:
|
|
|
+ self.checkpoint_saver.stop.set()
|
|
|
+ self.checkpoint_saver.join()
|
|
|
+
|
|
|
+ self.dht.shutdown()
|
|
|
+ self.dht.join()
|
|
|
+
|
|
|
+ logger.debug(f"Shutting down runtime")
|
|
|
+
|
|
|
+ self.runtime.shutdown()
|
|
|
+ logger.info("Server shutdown succesfully")
|
|
|
+
|