Implement block selection on servers (#20)

Alexander Borzunov 3 years ago
parent
commit
aba43f1308

+ 2 - 2
README.md

@@ -43,7 +43,7 @@ python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6
 # - give each server a unique --identity_path (or remote --identity_path arg when debugging)
 # - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
 # - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers 
+# - each server except first should have --initial_peers pointing to one of pre-existing servers
 ```

 Then open a python notebook or console and run:
@@ -66,7 +66,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()

 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```

+ 2 - 0
cli/run_server.py

@@ -41,6 +41,8 @@ def main():
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
 
+    parser.add_argument('--throughput', type=float, default=1.0,
+                        help='Expected server throughput')
     parser.add_argument('--update_period', type=float, required=False, default=30,
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,

+ 1 - 1
src/client/remote_block.py

@@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""

     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
         super().__init__(peer_info, p2p)

     @property

+ 16 - 17
src/client/remote_sequence_info.py

@@ -1,15 +1,13 @@
 from __future__ import annotations
 
-import dataclasses
 import threading
-from functools import partial
 from typing import List, NamedTuple, Optional, Sequence, Tuple

 from hivemind import DHT, PeerID
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo
-from src.dht_utils import _get_remote_module_infos
+from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
+from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -18,21 +16,20 @@ logger = get_logger(__file__)
 Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 
 
-@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] eto ne dataclass
 class RemoteSequenceInfo:
     """Keeps and updates the meta-information about which peers host which blocks"""

     dht: DHT
-    block_uids: List[ModuleUID, ...]
-    block_infos: List[Optional[RemoteModuleInfo], ...]
+    block_uids: List[ModuleUID]
+    block_infos: List[Optional[RemoteModuleInfo]]
     spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span], ...]
+    spans_containing_block: Tuple[List[Span]]
     lock_changes: threading.Lock

     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
         self.dht = dht
         self.block_uids = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
+        self.block_infos = [None] * len(self.block_uids)
         self.spans_by_priority = []
         self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
         self.lock_changes = threading.Lock()
@@ -48,21 +45,17 @@ class RemoteSequenceInfo:
             self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)

     def update_block_infos_(self):
-        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
-        )
+        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.peer_ids:
+            if not info.servers:
                 logger.warning(f"Found no active peers for block {uid}")
             if info.uid != uid:
                 logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            if not isinstance(info.peer_ids, set):
-                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
             self.block_infos[block_index] = info

     @staticmethod
@@ -70,14 +63,20 @@ class RemoteSequenceInfo:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id in info.peer_ids:
+            for peer_id, server in info.servers.items():
+                if server.state != ServerState.ONLINE:
+                    continue
                 if peer_id not in active_spans:
                     active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                 else:  # peer_id in active_spans
                     active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)

             for peer_id in list(active_spans.keys()):
-                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                if (
+                    peer_id not in info.servers or
+                    info.servers[peer_id].state != ServerState.ONLINE or
+                    block_index == len(block_infos) - 1
+                ):
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans
 
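For intuition, the span-building logic above with the new ONLINE filter boils down to the following self-contained sketch. It is not code from this commit: plain strings stand in for hivemind PeerID objects and for ServerState values, and toy_compute_spans is a hypothetical name.

```python
from typing import NamedTuple, Optional

Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", str)])

def toy_compute_spans(block_servers):
    """block_servers holds one {peer_id: state} dict per block, mirroring info.servers."""
    closed_spans, active_spans = [], {}
    for block_index, servers in enumerate(block_servers):
        for peer_id, state in servers.items():
            if state != "ONLINE":  # the commit now skips servers that are not ServerState.ONLINE
                continue
            if peer_id not in active_spans:  # open a new span for this peer
                active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
            else:  # extend the peer's current span by one block
                active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
        for peer_id in list(active_spans.keys()):  # close spans that end at this block
            if (peer_id not in servers or servers[peer_id] != "ONLINE"
                    or block_index == len(block_servers) - 1):
                closed_spans.append(active_spans.pop(peer_id))
    return closed_spans

# "peerA" hosts blocks 0-2 but is OFFLINE for block 1, so its span is split in two:
print(toy_compute_spans([{"peerA": "ONLINE"}, {"peerA": "OFFLINE"}, {"peerA": "ONLINE"}]))
# [Span(start=0, end=1, peer_id='peerA'), Span(start=2, end=3, peer_id='peerA')]
```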

+ 21 - 2
src/data_structures.py

@@ -1,8 +1,27 @@
-from typing import Collection, NamedTuple
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict
 
 from hivemind import PeerID

 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
-RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
+
+
+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+@dataclass
+class ServerInfo:
+    state: ServerState
+    throughput: float
+
+
+@dataclass
+class RemoteModuleInfo:
+    uid: ModuleUID
+    servers: Dict[PeerID, ServerInfo]
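For context, here is how these new structures travel over the DHT once the src/dht_utils.py changes below are applied: each server stores the plain tuple (state.value, throughput), which readers parse back into a ServerInfo. A minimal sketch, assuming the repository root is on the Python path:

```python
from src.data_structures import ServerInfo, ServerState

stored_value = (ServerState.ONLINE.value, 1.0)  # what _declare_active_modules writes per peer
state, throughput = stored_value                # what _get_remote_module_infos reads back
info = ServerInfo(ServerState(state), throughput)
assert info.state is ServerState.ONLINE and info.throughput == 1.0
```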

+ 40 - 13
src/dht_utils.py

@@ -12,7 +12,7 @@ from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler

 import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -22,7 +22,8 @@ def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
+    state: ServerState,
+    throughput: float,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -30,7 +31,7 @@ def declare_active_modules(
 
     :param uids: a list of module ids to declare
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param throughput: optionally specify your performance in terms of compute throughput
+    :param throughput: specify your performance in terms of compute throughput
     :param expiration_time: declated modules will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
@@ -41,7 +42,13 @@ def declare_active_modules(
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
     return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
+        partial(
+            _declare_active_modules,
+            uids=uids,
+            expiration_time=expiration_time,
+            state=state,
+            throughput=throughput,
+        ),
         return_future=not wait,
     )
 
@@ -51,13 +58,14 @@ async def _declare_active_modules(
     node: DHTNode,
     uids: List[ModuleUID],
     expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
+    state: ServerState,
+    throughput: float,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[throughput] * len(uids),
+        values=[(state.value, throughput)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -94,6 +102,19 @@ def get_remote_module(
     return modules[0] if single_uid else modules
 
 
+def get_remote_module_infos(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+) -> List[Optional[RemoteModuleInfo]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
+    )
+    return infos[0] if single_uid else infos
+
+
 async def _get_remote_module_infos(
     dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
 ) -> List[Optional[RemoteModuleInfo]]:
@@ -109,14 +130,20 @@ async def _get_remote_module_infos(
             if metadata is not None:
                 logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
-        valid_entries = set()
-        for maybe_peer_id, _unused_value in metadata.value.items():
+        servers = {}
+        for peer_id, server_info in metadata.value.items():
             try:
-                valid_entries.add(PeerID.from_base58(maybe_peer_id))
-            except:
-                logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
-        if valid_entries:
-            modules[i] = RemoteModuleInfo(uid, valid_entries)
+                peer_id = PeerID.from_base58(peer_id)
+                server_info = server_info.value
+                if not (isinstance(server_info, tuple) and len(server_info) == 2 and
+                        isinstance(server_info[0], int) and isinstance(server_info[1], float)):
+                    raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
+                state, throughput = server_info
+                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+            except (TypeError, ValueError) as e:
+                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+        if servers:
+            modules[i] = RemoteModuleInfo(uid, servers)
     return modules
 
 

+ 18 - 0
src/server/block_selection.py

@@ -0,0 +1,18 @@
+from typing import List, Optional
+
+from src.data_structures import RemoteModuleInfo, ServerState
+
+
+def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+    throughputs = []
+    for module in remote_module_infos:
+        if module is None:
+            throughputs.append(0)
+            continue
+        throughputs.append(sum(server.throughput for server in module.servers.values()
+                               if server.state != ServerState.OFFLINE))
+
+    options = [(throughputs[i:i + num_blocks], i)
+               for i in range(0, len(throughputs) - num_blocks + 1)]
+    best_start = min(options)[1]
+    return list(range(best_start, best_start + num_blocks))
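A toy run of the new helper (not part of the commit; plain strings stand in for hivemind PeerID keys, which choose_best_blocks never inspects, and the repository root is assumed to be on the Python path):

```python
from src.data_structures import RemoteModuleInfo, ServerInfo, ServerState
from src.server.block_selection import choose_best_blocks

module_infos = [
    RemoteModuleInfo("bloom.0", {"peerA": ServerInfo(ServerState.ONLINE, 1.0)}),
    RemoteModuleInfo("bloom.1", {"peerA": ServerInfo(ServerState.OFFLINE, 1.0)}),  # OFFLINE is ignored
    RemoteModuleInfo("bloom.2", {}),  # no servers -> throughput 0
    None,                             # no DHT entry -> throughput 0
]
# Per-block throughputs are [1.0, 0, 0, 0]. The size-2 windows are ([1.0, 0], 0),
# ([0, 0], 1) and ([0, 0], 2); min() picks the lexicographically smallest throughput
# profile, breaking ties by the lower start index, i.e. the least-served blocks.
print(choose_best_blocks(2, module_infos))  # [1, 2]
```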

+ 63 - 16
src/server/server.py

@@ -13,8 +13,10 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules, BloomConfig
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from src.dht_utils import get_remote_module_infos
 from src.server.backend import TransformerBackend
+from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
@@ -32,19 +34,26 @@ class Server(threading.Thread):
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
+        throughput: float,
         update_period: float = 30,
         expiration: Optional[float] = None,
         start: bool,
         **kwargs,
     ):
         threading.Thread.__init__(self)
-        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends, dht, update_period, expiration, daemon=True
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
         )
         )
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
@@ -86,6 +95,7 @@ class Server(threading.Thread):
         cls,
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
         num_blocks: Optional[int] = None,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
             )
             )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
             torch_dtype = DTYPE_MAP[torch_dtype]
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token
+        )
+
         if block_indices is not None:
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
             block_indices = range(first_block_index, last_block_index)
             block_indices = range(first_block_index, last_block_index)
         else:
             assert num_blocks is not None
+            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
+            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
+            block_indices = choose_best_blocks(num_blocks, module_infos)
 
 
-            converted_model_name_or_path, use_auth_token=use_auth_token
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        declare_active_modules(
+            dht,
+            module_uids,
+            expiration_time=get_dht_time() + expiration,
+            state=ServerState.JOINING,
+            throughput=throughput,
         )
         )
+        logger.info(f"Announced that blocks {block_indices} are joining")
 
-        # initialize modules
         blocks = {}
-        for block_index in block_indices:
-            module_uid = f"{prefix}.{block_index}"
+        for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
                 converted_model_name_or_path,
                 block_index,
@@ -173,6 +196,7 @@ class Server(threading.Thread):
         return cls(
             dht,
             blocks,
+            throughput=throughput,
             num_connection_handlers=num_handlers,
             device=device,
             stats_report_interval=stats_report_interval,
@@ -209,6 +233,16 @@ class Server(threading.Thread):
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
+        if self.module_backends:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
         self.ready.clear()

         for process in self.conn_handlers:
@@ -230,25 +264,38 @@ class Server(threading.Thread):
         logger.debug(f"Shutting down runtime")

         self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
+        logger.info("Server shut down succesfully")
 
 
 class ModuleAnnouncerThread(threading.Thread):
     """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""

     def __init__(
-        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+        self,
+        module_backends: Dict[str, TransformerBackend],
+        dht: DHT,
+        *,
+        throughput: float,
+        update_period: float = 30,
+        expiration: float,
+        **kwargs
     ):
         super().__init__(**kwargs)
-        if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.module_backends = module_backends
         self.dht = dht
+        self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()

     def run(self) -> None:
-        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
-        while not self.stop.wait(self.update_period):
-            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
+        while True:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.ONLINE,
+                throughput=self.throughput,
+            )
+            if self.stop.wait(self.update_period):
+                break
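The rewritten run() above announces first and then sleeps, so a stop request takes effect within one update_period. The pattern in isolation (a standalone sketch with a stubbed announce function, not commit code):

```python
import threading
import time

stop = threading.Event()

def announce():
    # stand-in for declare_active_modules(..., state=ServerState.ONLINE, ...)
    print(f"announced ONLINE at {time.monotonic():.1f}")

def announcer_loop(update_period: float = 1.0):
    while True:
        announce()                    # announce immediately on startup...
        if stop.wait(update_period):  # ...then wait; returns True once stop is set
            break

thread = threading.Thread(target=announcer_loop, daemon=True)
thread.start()
time.sleep(2.5)  # let it announce a few times
stop.set()       # the loop exits within one update_period
thread.join()
```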