Browse Source

Implement block selection on servers (#20)

Alexander Borzunov 3 years ago
parent
commit
aba43f1308

+ 2 - 2
README.md

@@ -43,7 +43,7 @@ python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6
 # - give each server a unique --identity_path (or remote --identity_path arg when debugging)
 # - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
 # - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers 
+# - each server except first should have --initial_peers pointing to one of pre-existing servers
 ```
 
 Then open a python notebook or console and run:
@@ -66,7 +66,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```

+ 2 - 0
cli/run_server.py

@@ -41,6 +41,8 @@ def main():
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
 
+    parser.add_argument('--throughput', type=float, default=1.0,
+                        help='Expected server throughput')
     parser.add_argument('--update_period', type=float, required=False, default=30,
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,

+ 1 - 1
src/client/remote_block.py

@@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
         super().__init__(peer_info, p2p)
 
     @property

+ 16 - 17
src/client/remote_sequence_info.py

@@ -1,15 +1,13 @@
 from __future__ import annotations
 
-import dataclasses
 import threading
-from functools import partial
 from typing import List, NamedTuple, Optional, Sequence, Tuple
 
 from hivemind import DHT, PeerID
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo
-from src.dht_utils import _get_remote_module_infos
+from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
+from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -18,21 +16,20 @@ logger = get_logger(__file__)
 Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 
 
-@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] eto ne dataclass
 class RemoteSequenceInfo:
     """Keeps and updates the meta-information about which peers host which blocks"""
 
     dht: DHT
-    block_uids: List[ModuleUID, ...]
-    block_infos: List[Optional[RemoteModuleInfo], ...]
+    block_uids: List[ModuleUID]
+    block_infos: List[Optional[RemoteModuleInfo]]
     spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span], ...]
+    spans_containing_block: Tuple[List[Span]]
     lock_changes: threading.Lock
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
         self.dht = dht
         self.block_uids = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
+        self.block_infos = [None] * len(self.block_uids)
         self.spans_by_priority = []
         self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
         self.lock_changes = threading.Lock()
@@ -48,21 +45,17 @@ class RemoteSequenceInfo:
             self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
 
     def update_block_infos_(self):
-        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
-        )
+        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.peer_ids:
+            if not info.servers:
                 logger.warning(f"Found no active peers for block {uid}")
             if info.uid != uid:
                 logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            if not isinstance(info.peer_ids, set):
-                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
             self.block_infos[block_index] = info
 
     @staticmethod
@@ -70,14 +63,20 @@ class RemoteSequenceInfo:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id in info.peer_ids:
+            for peer_id, server in info.servers.items():
+                if server.state != ServerState.ONLINE:
+                    continue
                 if peer_id not in active_spans:
                     active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                 else:  # peer_id in active_spans
                     active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
 
             for peer_id in list(active_spans.keys()):
-                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                if (
+                    peer_id not in info.servers or
+                    info.servers[peer_id].state != ServerState.ONLINE or
+                    block_index == len(block_infos) - 1
+                ):
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans
 

+ 21 - 2
src/data_structures.py

@@ -1,8 +1,27 @@
-from typing import Collection, NamedTuple
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict
 
 from hivemind import PeerID
 
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
-RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
+
+
+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+@dataclass
+class ServerInfo:
+    state: ServerState
+    throughput: float
+
+
+@dataclass
+class RemoteModuleInfo:
+    uid: ModuleUID
+    servers: Dict[PeerID, ServerInfo]

+ 40 - 13
src/dht_utils.py

@@ -12,7 +12,7 @@ from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -22,7 +22,8 @@ def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
+    state: ServerState,
+    throughput: float,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -30,7 +31,7 @@ def declare_active_modules(
 
     :param uids: a list of module ids to declare
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param throughput: optionally specify your performance in terms of compute throughput
+    :param throughput: specify your performance in terms of compute throughput
     :param expiration_time: declated modules will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
@@ -41,7 +42,13 @@ def declare_active_modules(
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
     return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
+        partial(
+            _declare_active_modules,
+            uids=uids,
+            expiration_time=expiration_time,
+            state=state,
+            throughput=throughput,
+        ),
         return_future=not wait,
     )
 
@@ -51,13 +58,14 @@ async def _declare_active_modules(
     node: DHTNode,
     uids: List[ModuleUID],
     expiration_time: DHTExpiration,
-    throughput: Optional[float] = None,
+    state: ServerState,
+    throughput: float,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[throughput] * len(uids),
+        values=[(state.value, throughput)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -94,6 +102,19 @@ def get_remote_module(
     return modules[0] if single_uid else modules
 
 
+def get_remote_module_infos(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+) -> List[Optional[RemoteModuleInfo]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
+    )
+    return infos[0] if single_uid else infos
+
+
 async def _get_remote_module_infos(
     dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
 ) -> List[Optional[RemoteModuleInfo]]:
@@ -109,14 +130,20 @@ async def _get_remote_module_infos(
             if metadata is not None:
                 logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
-        valid_entries = set()
-        for maybe_peer_id, _unused_value in metadata.value.items():
+        servers = {}
+        for peer_id, server_info in metadata.value.items():
             try:
-                valid_entries.add(PeerID.from_base58(maybe_peer_id))
-            except:
-                logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
-        if valid_entries:
-            modules[i] = RemoteModuleInfo(uid, valid_entries)
+                peer_id = PeerID.from_base58(peer_id)
+                server_info = server_info.value
+                if not (isinstance(server_info, tuple) and len(server_info) == 2 and
+                        isinstance(server_info[0], int) and isinstance(server_info[1], float)):
+                    raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
+                state, throughput = server_info
+                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+            except (TypeError, ValueError) as e:
+                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+        if servers:
+            modules[i] = RemoteModuleInfo(uid, servers)
     return modules
 
 

+ 18 - 0
src/server/block_selection.py

@@ -0,0 +1,18 @@
+from typing import List, Optional
+
+from src.data_structures import RemoteModuleInfo, ServerState
+
+
+def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+    throughputs = []
+    for module in remote_module_infos:
+        if module is None:
+            throughputs.append(0)
+            continue
+        throughputs.append(sum(server.throughput for server in module.servers.values()
+                               if server.state != ServerState.OFFLINE))
+
+    options = [(throughputs[i:i + num_blocks], i)
+               for i in range(0, len(throughputs) - num_blocks + 1)]
+    best_start = min(options)[1]
+    return list(range(best_start, best_start + num_blocks))

+ 63 - 16
src/server/server.py

@@ -13,8 +13,10 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules, BloomConfig
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from src.dht_utils import get_remote_module_infos
 from src.server.backend import TransformerBackend
+from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
@@ -32,19 +34,26 @@ class Server(threading.Thread):
         *,
         device: torch.device,
         num_connection_handlers: int = 8,
+        throughput: float,
         update_period: float = 30,
         expiration: Optional[float] = None,
         start: bool,
         **kwargs,
     ):
         threading.Thread.__init__(self)
-        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends, dht, update_period, expiration, daemon=True
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
         )
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
@@ -86,6 +95,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
+        throughput: float,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -116,6 +126,9 @@ class Server(threading.Thread):
             )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@@ -127,6 +140,10 @@ class Server(threading.Thread):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token
+        )
+
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
@@ -137,16 +154,22 @@ class Server(threading.Thread):
             block_indices = range(first_block_index, last_block_index)
         else:
             assert num_blocks is not None
-            block_indices = range(num_blocks)  # TODO replace with proper load balancing
+            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
+            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
+            block_indices = choose_best_blocks(num_blocks, module_infos)
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        declare_active_modules(
+            dht,
+            module_uids,
+            expiration_time=get_dht_time() + expiration,
+            state=ServerState.JOINING,
+            throughput=throughput,
         )
+        logger.info(f"Announced that blocks {block_indices} are joining")
 
-        # initialize modules
         blocks = {}
-        for block_index in block_indices:
-            module_uid = f"{prefix}.{block_index}"
+        for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
                 converted_model_name_or_path,
                 block_index,
@@ -173,6 +196,7 @@ class Server(threading.Thread):
         return cls(
             dht,
             blocks,
+            throughput=throughput,
             num_connection_handlers=num_handlers,
             device=device,
             stats_report_interval=stats_report_interval,
@@ -209,6 +233,16 @@ class Server(threading.Thread):
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
+        if self.module_backends:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
         self.ready.clear()
 
         for process in self.conn_handlers:
@@ -230,25 +264,38 @@ class Server(threading.Thread):
         logger.debug(f"Shutting down runtime")
 
         self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
+        logger.info("Server shut down succesfully")
 
 
 class ModuleAnnouncerThread(threading.Thread):
     """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
 
     def __init__(
-        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+        self,
+        module_backends: Dict[str, TransformerBackend],
+        dht: DHT,
+        *,
+        throughput: float,
+        update_period: float = 30,
+        expiration: float,
+        **kwargs
     ):
         super().__init__(**kwargs)
-        if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.module_backends = module_backends
         self.dht = dht
+        self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
-        while not self.stop.wait(self.update_period):
-            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
+        while True:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.ONLINE,
+                throughput=self.throughput,
+            )
+            if self.stop.wait(self.update_period):
+                break