
Merge pull request #9 from learning-at-home/rpc

Rudimentary decentralization
justheuristic 3 years ago
parent
commit
a28ea0aa6f

+ 5 - 1
cli/quantize_for_cpu.py → cli/quantize_cpu_naive.py

@@ -5,9 +5,10 @@ import os
 import psutil
 import torch.backends.quantized
 import transformers
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tqdm.auto import trange
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
@@ -47,3 +48,6 @@ if __name__ == "__main__":
             layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
         )
         torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
+
+    model.transformer.h = torch.nn.ModuleList()
+    torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))
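The script now writes a block-free client.pth next to the per-block block_{i}_qint8.pth checkpoints. A minimal reload sketch, assuming the BloomBlock / BloomForCausalLM / DistributedBloomConfig classes introduced later in this diff and PyTorch's default dynamic-quantization mapping; the helper names below are illustrative, not part of the commit:

import os
import torch
from src.bloom.block import BloomBlock
from src.bloom.model import BloomForCausalLM, DistributedBloomConfig

def load_quantized_block(config: DistributedBloomConfig, i: int, output_path: str) -> BloomBlock:
    # Re-create block i, re-apply dynamic int8 quantization, then load the saved weights.
    block = BloomBlock(config, layer_number=i)
    block = torch.quantization.quantize_dynamic(block, {torch.nn.Linear}, dtype=torch.qint8)
    block.load_state_dict(torch.load(os.path.join(output_path, f"block_{i}_qint8.pth")))
    return block

def load_client_model(config: DistributedBloomConfig, output_path: str) -> BloomForCausalLM:
    # Embeddings and LM head only: transformer.h was emptied before client.pth was saved.
    model = BloomForCausalLM(config)
    model.transformer.h = torch.nn.ModuleList()
    model.load_state_dict(torch.load(os.path.join(output_path, "client.pth")))
    return model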

+ 78 - 0
cli/run_server.py

@@ -0,0 +1,78 @@
+import os, sys
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # add path to src
+
+import configargparse
+
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from src.server.server import Server
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+def main():
+    # fmt:off
+    parser = configargparse.ArgParser(default_config_files=["config.yml"])
+    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
+
+    parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
+    parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
+    parser.add_argument('--host_maddrs', type=str, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', type=str, nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
+    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression for communication')
+
+    parser.add_argument('--num_handlers', type=int, default=None, required=False,
+                        help='server will use this many processes to handle incoming requests')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all expert operations')
+    parser.add_argument('--max_batch_size', type=int, default=16384,
+                        help='The total number of examples in the same batch will not exceed this value')
+    parser.add_argument('--cache_size_bytes', type=int, default=None,
+                        help='The size of memory cache for storing past attention keys/values between inference steps')
+    parser.add_argument('--device', type=str, default=None, required=False,
+                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
+
+    parser.add_argument('--update_period', type=float, required=False, default=30,
+                        help='Server will report experts to DHT once in this many seconds')
+    parser.add_argument('--expiration', type=float, required=False, default=None,
+                        help='DHT entries will expire after this many seconds')
+    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
+                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
+    parser.add_argument('--increase_file_limit', action='store_true',
+                        help='On *nix, this will increase the max number of processes '
+                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--stats_report_interval', type=int, required=False,
+                        help='Interval between two reports of batch processing performance statistics')
+
+    parser.add_argument('--custom_module_path', type=str, required=False,
+                        help='Path of a file with custom nn.modules, wrapped into special decorator')
+
+    # fmt:on
+    args = vars(parser.parse_args())
+    args.pop("config", None)
+
+    if args.pop("increase_file_limit"):
+        increase_file_limit()
+
+    compression_type = args.pop("compression")
+    compression = getattr(CompressionType, compression_type)
+
+    server = Server.create(**args, start=True, compression=compression)
+
+    try:
+        server.join()
+    except KeyboardInterrupt:
+        logger.info("Caught KeyboardInterrupt, shutting down")
+    finally:
+        server.shutdown()
+
+
+if __name__ == "__main__":
+    main()
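For reference, the same launch can be expressed programmatically via Server.create() (defined in src/server/server.py below); a hedged sketch mirroring `python cli/run_server.py --num_blocks 2`:

from hivemind.proto.runtime_pb2 import CompressionType
from src.server.server import Server

# Extra keyword arguments (e.g. host_maddrs, initial_peers) are forwarded to the underlying DHT.
server = Server.create(
    num_blocks=2,
    block_config="bigscience/bloom",
    compression=CompressionType.NONE,
    start=True,
)
try:
    server.join()  # Server is a threading.Thread; block until it stops
except KeyboardInterrupt:
    pass
finally:
    server.shutdown()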

+ 1 - 0
src/__init__.py

@@ -0,0 +1 @@
+from .bloom import *

+ 1 - 0
src/bloom/__init__.py

@@ -0,0 +1 @@
+from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig

+ 6 - 5
src/block.py → src/bloom/block.py

@@ -9,13 +9,14 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.ops import (
+from src.bloom.ops import (
     BloomGelu,
     BloomScaledSoftmax,
     attention_mask_func,
     dropout_add,
     pre_process_alibi_for_pad,
-    split_tensor_along_last_dim, build_alibi_tensor,
+    split_tensor_along_last_dim,
+    build_alibi_tensor,
 )
 
 
@@ -204,12 +205,12 @@ class BloomMLP(nn.Module):
 class BloomBlock(nn.Module):
     def __init__(self, config, layer_number=None):
         super().__init__()
-        hidden_size = config.hidden_size
+        self.hidden_size = config.hidden_size
 
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
         self.n_head = config.n_head
         self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = BloomMLP(config)
 

+ 6 - 4
src/model.py → src/bloom/model.py

@@ -8,6 +8,7 @@ from typing import Tuple
 
 import torch
 import torch.utils.checkpoint
+from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from transformers.file_utils import (
@@ -20,9 +21,10 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
 from transformers.utils import logging
 
-from src.block import BloomBlock
-from src.ops import build_alibi_tensor
+from src.bloom.block import BloomBlock
+from src.bloom.ops import build_alibi_tensor
 
+use_hivemind_log_handler("in_root_logger")
 logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
@@ -30,7 +32,7 @@ _CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
-class MemoryEfficientBloomConfig(_VanillaBloomConfig):
+class DistributedBloomConfig(_VanillaBloomConfig):
     compression: str = "none"
     slow_but_exact: bool = False
 
@@ -42,7 +44,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     models.
     """
 
-    config_class = MemoryEfficientBloomConfig
+    config_class = DistributedBloomConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
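After the rename, downstream code loads the config like any other transformers config; a minimal sketch (use_auth_token mirrors the call in Server.create further down, and the printed fields are just for illustration):

from src.bloom.model import DistributedBloomConfig

config = DistributedBloomConfig.from_pretrained("bigscience/bloom", use_auth_token=True)
print(config.hidden_size, config.n_head, config.compression)  # compression defaults to "none"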

+ 0 - 0
src/ops.py → src/bloom/ops.py


+ 1 - 0
src/client/__init__.py

@@ -0,0 +1 @@
+from src.client.remote_block import RemoteTransformerBlock

+ 57 - 0
src/client/remote_block.py

@@ -0,0 +1,57 @@
+from concurrent.futures import Future
+from functools import partial
+from typing import List, Optional, Union, Sequence
+
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.moe.server.dht_handler import _get_experts
+from hivemind.p2p import StubBase, P2P
+from hivemind.proto.runtime_pb2 import ExpertInfo
+from hivemind.dht import DHTExpiration, DHT
+from hivemind.utils import MPFuture
+from src.server.handler import TransformerConnectionHandler
+
+
+class RemoteTransformerBlock(RemoteExpert):
+    @property
+    def stub(self) -> StubBase:
+        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
+
+
+def get_remote_module(
+    dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
+    """
+    :param uids: find experts with these ids from across the DHT
+    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock if found else None]
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    return create_remote_module(result, dht, return_future)
+
+
+def create_remote_module(
+    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
+    experts: List[Optional[RemoteTransformerBlock]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteTransformerBlock(info, p2p))
+        else:
+            experts.append(None)
+    return experts
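A hedged client-side sketch of how these helpers fit together; the multiaddr is a placeholder, and dummy_block.0 is the placeholder UID that Server.create (below) announces to the DHT:

import hivemind
from src.client.remote_block import get_remote_module

# Replace the multiaddr with one of the addresses printed by cli/run_server.py on startup.
dht = hivemind.DHT(initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/<peer_id>"], start=True)

blocks = get_remote_module(dht, ["dummy_block.0"])  # -> List[Optional[RemoteTransformerBlock]]
remote_block = blocks[0]
if remote_block is not None:
    print(remote_block.stub)  # RPC stub backed by TransformerConnectionHandler on the serving peer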

+ 0 - 0
src/server/__init__.py


+ 34 - 0
src/server/backend.py

@@ -0,0 +1,34 @@
+"""Code for serving bloom blocks via hivemind-server"""
+from typing import Tuple
+
+import torch
+from hivemind import BatchTensorDescriptor
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
+
+from src.bloom.block import BloomBlock
+from src.server.cache import MemoryCache
+
+
+# TODO
+# BloomBackend serves a single layer
+# - ensure that parameters do not require grad!
+# - ensure that TaskPool for inference is NOT batched
+# - ensure that optimizer/scheduler is not created
+
+
+class BloomBlockBackend(ExpertBackend):
+    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
+
+    def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
+        super().__init__(*args, **kwargs)  # to bypass super.__init__
+        self.memory_cache = memory_cache
+
+        for name, param in self.module.named_parameters():
+            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+        for name, buf in self.module.named_buffers():
+            assert not buf.requires_grad, f"Bloom layer buffers must not accumulate gradients, but {name} does"
+
+    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
+        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
+            raise NotImplementedError("TODO")

+ 134 - 0
src/server/cache.py

@@ -0,0 +1,134 @@
+"""
+A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
+
+For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
+
+TODO In future, one could modify cache to implement, among other things,
+- in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
+-- note: this can be done using mp.Condition
+- allocate cache as one contiguous buffer to avoid fragmentation
+- quantize cached values using bitsandbytes
+- LRU offloading from gpu to ram
+
+"""
+import contextlib
+import ctypes
+import multiprocessing as mp
+import os
+from typing import Dict, Optional, Union
+
+import hivemind
+import torch
+from hivemind import use_hivemind_log_handler
+from hivemind.utils import TensorDescriptor, get_logger
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+Handle = int
+
+
+class MemoryCache:
+    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
+
+    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
+        self.device = device
+        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
+        self._handle_counter = mp.Value(ctypes.c_uint64, 0, lock=False)
+        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
+        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+        self.runtime_pid = os.getpid()
+
+        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
+        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+
+    @property
+    def current_size_bytes(self) -> int:
+        return self._current_size.value
+
+    @current_size_bytes.setter
+    def current_size_bytes(self, value: int):
+        self._current_size.value = value
+
+    @property
+    def handle_counter(self) -> int:
+        return self._handle_counter.value
+
+    @handle_counter.setter
+    def handle_counter(self, value: int):
+        self._handle_counter.value = value
+
+    @contextlib.asynccontextmanager
+    async def allocate_cache(self, descr: TensorDescriptor) -> Handle:
+        """
+        Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
+
+        :param descr: allocate a tensor of this size, dtype, etc
+
+        :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
+        Furthermore, it can be called concurrently with at most one use_cache call in runtime.
+        """
+        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
+        assert descr.device is None and descr
+        allocated_handle = None
+        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        try:
+            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+                    raise AllocationFailed(
+                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
+                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
+                    )
+
+                allocated_handle = int(self.handle_counter)
+                self.current_size_bytes += allocated_size_bytes
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                self._pending_messages.value += 1
+                self._pipe_send.send((allocated_handle, descr))
+
+            yield allocated_handle
+        finally:
+            if allocated_handle is not None:
+                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
+                    self.current_size_bytes -= allocated_size_bytes
+
+    @contextlib.contextmanager
+    def use_cache(self, handle: Handle) -> torch.Tensor:
+        """
+        Return a tensor that was previously allocated with allocate_cache.
+
+        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
+        """
+        assert os.getpid() == self.runtime_pid
+        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
+
+        with self.lock_metadata:
+            if self._allocated_tensors is None:
+                self._allocated_tensors = {}
+
+            # read creation/deletion requests from connection handlers
+            for i in range(int(self._pending_messages.value)):
+                recv_handle, recv_data = self._pipe_recv.recv()
+                self._pending_messages.value -= 1
+                if isinstance(recv_data, TensorDescriptor):
+                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
+                elif recv_data is None:
+                    if recv_handle not in self._allocated_tensors:
+                        logger.warning(
+                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
+                        )
+                    self._allocated_tensors.pop(recv_handle, None)
+                else:
+                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
+
+        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+        yield self._allocated_tensors[handle]
+
+
+class AllocationFailed(Exception):
+    pass
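To make the intended call pattern concrete, a hedged sketch follows; the shapes and dtype are illustrative, and the TensorDescriptor construction assumes hivemind's size/dtype fields. In practice the two functions run in different processes that share the same MemoryCache object:

import torch
from hivemind.utils import TensorDescriptor
from src.server.cache import MemoryCache

cache = MemoryCache(device="cpu", max_size_bytes=1 << 30)

# Connection-handler side (any process, async): reserve space for attention keys/values.
async def handler_side(cache: MemoryCache) -> None:
    descr = TensorDescriptor(size=(1, 2048, 4096), dtype=torch.bfloat16)
    async with cache.allocate_cache(descr) as handle:
        ...  # ship `handle` to the runtime together with the inference request

# Runtime side (single process): look up the buffer by handle between inference steps.
def runtime_side(cache: MemoryCache, handle: int) -> None:
    with cache.use_cache(handle) as attention_kv:
        attention_kv.zero_()  # read or update the cached values in-place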

+ 24 - 0
src/server/handler.py

@@ -0,0 +1,24 @@
+from typing import AsyncIterator, Dict
+
+from hivemind import P2PContext, DHT
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.proto import runtime_pb2
+
+from src.bloom.block import BloomBlock
+
+
+class TransformerConnectionHandler(ConnectionHandler):
+    """Handles three request types: forward, backward and forward-incremental (inference)"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def rpc_forward_incremental(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        # note: you may use self.experts[uid].memory_cache!
+        # encode expert_uid as @model_name[starting_layer:finishing_layer]
+        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
+        # - receive and maintain a handle for attention cache here
+
+        raise NotImplementedError()

+ 192 - 0
src/server/server.py

@@ -0,0 +1,192 @@
+from __future__ import annotations
+import threading
+from typing import Optional, Dict, Union, Sequence
+
+import torch
+from hivemind import DHT, BatchTensorDescriptor
+from hivemind.moe.server.dht_handler import DHTHandlerThread
+from hivemind.moe.server.layers import add_custom_models_from_file
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+import multiprocessing as mp
+
+from src import DistributedBloomConfig
+from src.bloom.block import BloomBlock
+from src.server.cache import MemoryCache
+from src.server.backend import BloomBlockBackend
+from src.server.handler import TransformerConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class Server(threading.Thread):
+    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
+
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, BloomBlockBackend],
+        *,
+        device: torch.device,
+        num_connection_handlers: int = 8,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
+        start: bool,
+        **kwargs,
+    ):
+        threading.Thread.__init__(self)
+        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.conn_handlers = [
+            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+        ]
+        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
+        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    def run(self):
+        """
+        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Serving {len(self.module_backends)} blocks:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters())
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.module_backends:
+            self.dht_handler_thread.start()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.result()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        num_blocks: int,
+        block_config: str,
+        num_handlers: Optional[int] = None,
+        min_batch_size: int = 1,
+        max_batch_size: int = 4096,
+        cache_size_bytes: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        initial_peers: Sequence[str] = (),
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
+        *,
+        start: bool,
+        **kwargs,
+    ) -> Server:
+        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+
+        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
+        memory_cache = MemoryCache(device, cache_size_bytes)
+        # initialize modules
+        blocks = {}
+        for i in range(num_blocks):
+            module_uid = f"dummy_block.{i}"
+            block = BloomBlock(block_config, layer_number=i)
+            for param in block.parameters():
+                param.requires_grad = False
+
+            blocks[module_uid] = BloomBlockBackend(
+                module_uid,
+                block,
+                memory_cache=memory_cache,
+                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                kwargs_schema={},
+                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+        return cls(
+            dht,
+            blocks,
+            num_connection_handlers=num_handlers,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            start=start,
+        )
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts Server in a background thread. If await_ready, this method will wait until the background server
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the server, process-safe.
+        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
+
+        for process in self.conn_handlers:
+            process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.module_backends:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        self.dht.shutdown()
+        self.dht.join()
+
+        logger.debug(f"Shutting down runtime")
+
+        self.runtime.shutdown()
+        logger.info("Server shutdown succesfully")