@@ -4,7 +4,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, Optional, Sequence, Union
+from typing import Dict, Optional, List, Sequence, Union
 
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -29,76 +29,14 @@ logger = get_logger(__file__)
 
 
 class Server(threading.Thread):
-    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
+    """
+    Runs Server, periodically checks that the network is balanced,
+    and restarts the Server with other layers if the imbalance is significant
+    """
 
     def __init__(
         self,
-        dht: DHT,
-        module_backends: Dict[str, TransformerBackend],
-        *,
-        inference_max_length: int,
-        num_connection_handlers: int = 8,
-        throughput: float,
-        update_period: float = 30,
-        expiration: Optional[float] = None,
-        start: bool,
-        **kwargs,
-    ):
-        threading.Thread.__init__(self)
-        self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
-        self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
-            for _ in range(num_connection_handlers)
-        ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
-            dht,
-            throughput=throughput,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for block_name, backend in self.module_backends.items():
-            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            parameter_msg = f"{num_parameters} trainable parameters" if num_parameters else "frozen"
-            logger.info(f"{block_name}: {backend.module.__class__.__name__}, {parameter_msg}")
-
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
-
-        if self.module_backends:
-            self.dht_handler_thread.start()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.result()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
         throughput: Union[float, str],
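The rewritten Server above is no longer the worker itself but a thin supervisor thread. A minimal, runnable sketch of that pattern (illustrative names only, not this module's API): the supervisor owns a threading.Event, builds its worker inside run(), parks on the event until shutdown() is called, and tears the worker down in a finally block.

    import threading

    class Worker:
        def start(self):
            print("worker started")

        def shutdown(self):
            print("worker stopped")

    class Supervisor(threading.Thread):
        def __init__(self, start: bool = False):
            super().__init__()
            self.stop = threading.Event()
            if start:
                self.start()

        def run(self):
            worker = Worker()    # built lazily, so a replacement can be created by a fresh run()
            worker.start()
            try:
                self.stop.wait()     # park until shutdown() is requested
            finally:
                worker.shutdown()    # always executed, even if run() is unwinding

        def shutdown(self):
            self.stop.set()

    supervisor = Supervisor(start=True)
    supervisor.shutdown()
    supervisor.join()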
@@ -127,10 +65,26 @@ class Server(threading.Thread):
         *,
         start: bool,
         **kwargs,
-    ) -> Server:
+    ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
+
+        super().__init__()
+
+        self.converted_model_name_or_path = converted_model_name_or_path
+        self.num_handlers = num_handlers
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+        self.cache_dir = cache_dir
+        self.attn_cache_size = attn_cache_size
+        self.compression = compression
+        self.stats_report_interval, self.update_period = stats_report_interval, update_period
+        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
+        self.use_auth_token = use_auth_token
+        self.load_in_8bit = load_in_8bit
+
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
+
         if prefix is None:
             prefix = converted_model_name_or_path
             assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
@@ -138,27 +92,37 @@ class Server(threading.Thread):
                 f"Please specify --prefix manually when starting a server"
             )
             logger.info(f"Automatic dht prefix: {prefix}")
+        self.prefix = prefix
+
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        self.expiration = expiration
 
-        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
-        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, attn_cache_size)
+        self.device = device
+
+        self.memory_cache = MemoryCache(device, attn_cache_size)
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
+        self.throughput = throughput
 
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
+        self.block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
         )
 
         if block_indices is not None:
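The expiration default above keeps a DHT entry alive for whichever is longer: two announcement periods (so a single missed re-announce does not drop the entry) or the maximum clock discrepancy tolerated between peers. A worked example; the constant's value of 3 seconds here is an assumption for illustration, not hivemind's actual figure:

    MAX_DHT_TIME_DISCREPANCY_SECONDS = 3.0   # assumed value, for illustration only

    update_period = 30.0
    expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
    assert expiration == 60.0   # entries survive one missed announcement

    update_period = 1.0          # a very chatty announcer
    expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
    assert expiration == 3.0     # the clock-skew bound dominates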
@@ -175,10 +139,148 @@ class Server(threading.Thread):
             time.sleep(random.random() * max_block_selection_delay)
 
             assert num_blocks is not None
-            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
-            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
+            uids = [f"{prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+            module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
             block_indices = choose_best_blocks(num_blocks, module_infos)
+        self.block_indices = block_indices
+
+        self.stop = threading.Event()
+        if start:
+            self.start()
+
+    def run(self):
+        self.module_container = ModuleContainer.create(
+            dht=self.dht,
+            prefix=self.prefix,
+            converted_model_name_or_path=self.converted_model_name_or_path,
+            block_config=self.block_config,
+            memory_cache=self.memory_cache,
+            throughput=self.throughput,
+            block_indices=self.block_indices,
+            num_handlers=self.num_handlers,
+            min_batch_size=self.min_batch_size,
+            max_batch_size=self.max_batch_size,
+            inference_max_length=self.inference_max_length,
+            torch_dtype=self.torch_dtype,
+            cache_dir=self.cache_dir,
+            device=self.device,
+            compression=self.compression,
+            stats_report_interval=self.stats_report_interval,
+            update_period=self.update_period,
+            expiration=self.expiration,
+            prefetch_batches=self.prefetch_batches,
+            sender_threads=self.sender_threads,
+            use_auth_token=self.use_auth_token,
+            load_in_8bit=self.load_in_8bit,
+            start=True,
+        )
+        try:
+            self.stop.wait()
+        finally:
+            self.module_container.shutdown()
+
+    def shutdown(self):
+        self.stop.set()
+
+        self.dht.shutdown()
+        self.dht.join()
+
+
+class ModuleContainer(threading.Thread):
+    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
+
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        *,
+        device: torch.device,
+        num_connection_handlers: int,
+        throughput: float,
+        update_period: float,
+        expiration: Optional[float] = None,
+        start: bool,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.conn_handlers = [
+            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+        ]
+        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.dht_handler_thread = ModuleAnnouncerThread(
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    def run(self):
+        """
+        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Serving {len(self.module_backends)} blocks:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.module_backends:
+            self.dht_handler_thread.start()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.result()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
         declare_active_modules(
             dht,
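For orientation, a hedged usage sketch of the refactored Server (argument values are illustrative; run_server.py remains the documented entry point, and extra keyword arguments such as initial_peers are forwarded to the DHT):

    server = Server(
        prefix=None,                          # derived from the model name when omitted
        converted_model_name_or_path="bigscience/test-bloomd-6b3",  # illustrative model
        throughput="auto",                    # measured via get_host_throughput
        num_blocks=2,                         # or block_indices=..., but not both
        start=True,                           # run() builds and starts a ModuleContainer
    )
    try:
        ...                                   # serve until we decide to stop
    finally:
        server.shutdown()                     # sets the stop event and shuts down the DHT
        server.join()                         # run() shuts the container down on exit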
@@ -245,33 +347,36 @@ class Server(threading.Thread):
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
-        Starts Server in a background thread. if await_ready, this method will wait until background server
+        Starts ModuleContainer in a background thread. If await_ready, this method will wait until the container
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
         if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+            raise TimeoutError(f"ModuleContainer didn't notify .ready in {timeout} seconds")
 
     @property
     def ready(self) -> mp.synchronize.Event:
         """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+        An event (multiprocessing.Event) that is set when the container is ready to process requests.
 
         Example
         =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        >>> container.start()
+        >>> container.ready.wait(timeout=10)
+        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
     def shutdown(self):
         """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
+        Gracefully terminate the container, process-safe.
+        Please note that terminating the container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         if self.module_backends:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
             declare_active_modules(
                 self.dht,
                 self.module_backends.keys(),
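The run_in_background()/ready pair above implements a readiness handshake over a multiprocessing event set by the Runtime. A generic, runnable sketch of the same handshake (plain multiprocessing, not the project's Runtime):

    import multiprocessing as mp
    import time

    def serve(ready):
        time.sleep(0.1)   # pretend to load backends
        ready.set()       # signal readiness to the parent process

    if __name__ == "__main__":
        ready = mp.Event()
        worker = mp.Process(target=serve, args=(ready,))
        worker.start()
        if not ready.wait(timeout=10):
            raise TimeoutError("worker didn't notify .ready in 10 seconds")
        print("worker ready")
        worker.join()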
@@ -288,25 +393,18 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")
 
-        if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
-        self.dht.shutdown()
-        self.dht.join()
-
         logger.debug(f"Shutting down runtime")
-
         self.runtime.shutdown()
-        logger.info("Server shut down succesfully")
+
+        logger.info("Module container shut down successfully")
 
 
 class ModuleAnnouncerThread(threading.Thread):
-    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
+    """Periodically announces that this container hosts the specified modules, visible to all DHT peers"""
 
     def __init__(
         self,