
Make a server ping next servers (#356)

This PR makes a server ping potential next servers in a chain and report the measured RTTs to the DHT. These RTTs will be used for shortest-path routing.
Alexander Borzunov 2 years ago
parent
commit
81c4a45ca2
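For context, a minimal sketch of how a routing client could consume the reported RTTs. The routing logic itself is not part of this PR, and pick_fastest_next_server is a hypothetical helper, not an API added here:

    # Hypothetical consumer of ServerInfo.next_pings (not part of this PR).
    # next_pings maps base58 peer IDs of candidate next servers to smoothed RTTs in seconds.
    import math
    from typing import Optional

    from petals.data_structures import ServerInfo

    def pick_fastest_next_server(current_info: ServerInfo) -> Optional[str]:
        if not current_info.next_pings:
            return None  # the server reported no pings (e.g. it is shutting down)
        peer_id, rtt = min(current_info.next_pings.items(), key=lambda item: item[1])
        return peer_id if math.isfinite(rtt) else None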

+ 1 - 1
src/petals/cli/run_server.py

@@ -98,7 +98,7 @@ def main():
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
                              'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
-    parser.add_argument('--update_period', type=float, required=False, default=150,
+    parser.add_argument('--update_period', type=float, required=False, default=60,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')

+ 1 - 0
src/petals/data_structures.py

@@ -30,6 +30,7 @@ class ServerInfo:
     quant_type: Optional[str] = None
     using_relay: Optional[bool] = None
     cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
+    next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None
 
     def to_tuple(self) -> Tuple[int, float, dict]:
         extra_info = dataclasses.asdict(self)
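For illustration, the populated field looks roughly like this (peer IDs and RTT values below are made up; see ModuleAnnouncerThread in server.py for how it is actually filled):

    # Illustrative value of ServerInfo.next_pings: base58 peer IDs of candidate
    # next servers mapped to their smoothed RTTs in seconds (values are made up).
    server_info.next_pings = {
        "QmExamplePeerAAAA": 0.047,
        "QmExamplePeerBBBB": 0.112,
    }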

+ 58 - 52
src/petals/server/server.py

@@ -8,6 +8,7 @@ import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
+import hivemind
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -30,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.ping import PingAggregator
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -64,7 +66,7 @@ class Server:
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
-        update_period: float = 150,
+        update_period: float = 60,
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
@@ -220,7 +222,7 @@ class Server:
             throughput=throughput,
             adapters=tuple(adapters),
             version=petals.__version__,
-            torch_dtype=str(torch_dtype).lstrip("torch."),
+            torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
         )
@@ -413,8 +415,8 @@ class ModuleContainer(threading.Thread):
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
         memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
 
-        server_info.state = ServerState.JOINING
-        joining_announcer = ModuleAnnouncerThread(
+        assert server_info.state == ServerState.JOINING
+        dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
             server_info,
@@ -424,7 +426,7 @@ class ModuleContainer(threading.Thread):
             expiration=expiration,
             daemon=True,
         )
-        joining_announcer.start()
+        dht_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
 
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
@@ -476,6 +478,8 @@ class ModuleContainer(threading.Thread):
                     max_batch_size=max_batch_size,
                 )
 
+            merge_inference_pools_inplace(blocks)
+
             if should_validate_reachability:
                 validate_reachability(dht.peer_id)
         except:
@@ -483,29 +487,15 @@ class ModuleContainer(threading.Thread):
             for backend in blocks.values():
                 backend.shutdown()
 
-            joining_announcer.stop.set()
-            joining_announcer.join()
-            server_info.state = ServerState.OFFLINE
-            declare_active_modules(
-                dht,
-                module_uids,
-                server_info,
-                expiration_time=get_dht_time() + expiration,
-            )
+            dht_announcer.announce(ServerState.OFFLINE)
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
-        else:
-            joining_announcer.stop.set()
-            joining_announcer.join()
-
-        merge_inference_pools_inplace(blocks)
 
         return cls(
             dht,
             dht_prefix,
             blocks,
-            block_config=block_config,
-            memory_cache=memory_cache,
+            dht_announcer=dht_announcer,
             server_info=server_info,
             update_period=update_period,
             expiration=expiration,
@@ -518,10 +508,9 @@ class ModuleContainer(threading.Thread):
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
-        block_config: PretrainedConfig,
-        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
+        dht_announcer: ModuleAnnouncerThread,
         server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
@@ -558,17 +547,8 @@ class ModuleContainer(threading.Thread):
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
 
-        self.server_info.state = ServerState.ONLINE
-        self.online_announcer = ModuleAnnouncerThread(
-            list(self.module_backends.keys()),
-            dht,
-            self.server_info,
-            block_config=block_config,
-            memory_cache=memory_cache,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
+        dht_announcer.announce(ServerState.ONLINE)
+        self.dht_announcer = dht_announcer
 
         if start:
             self.run_in_background(await_ready=True)
@@ -578,11 +558,6 @@ class ModuleContainer(threading.Thread):
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
-
-        self.online_announcer.start()
-
         for handler in self.conn_handlers:
             handler.run_in_background()
 
@@ -621,16 +596,7 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        self.online_announcer.stop.set()
-        self.online_announcer.join()
-
-        self.server_info.state = ServerState.OFFLINE
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            self.server_info,
-            expiration_time=get_dht_time() + self.expiration,
-        )
+        self.dht_announcer.announce(ServerState.OFFLINE)
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
         self.ready.clear()
@@ -666,8 +632,10 @@ class ModuleAnnouncerThread(threading.Thread):
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
-        update_period: float = 30,
+        update_period: float,
         expiration: float,
+        max_pinged: int = 5,
+        max_reported: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -678,20 +646,58 @@ class ModuleAnnouncerThread(threading.Thread):
         self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
-        self.stop = threading.Event()
+        self.trigger = threading.Event()
+
+        self.max_pinged, self.max_reported = max_pinged, max_reported
+        last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
+        dht_prefix, block_index = last_uid.split(UID_DELIMITER)
+        self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
+        self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
         while True:
+            start_time = time.perf_counter()
+
             self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
+            if self.server_info.state != ServerState.OFFLINE:
+                self._ping_next_servers()
+                self.server_info.next_pings = {
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
+                }
+            else:
+                self.server_info.next_pings = None  # No need to ping if we're disconnecting
+
             declare_active_modules(
                 self.dht,
                 self.module_uids,
                 self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
             )
-            if self.stop.wait(self.update_period):
+            if self.server_info.state == ServerState.OFFLINE:
                 break
 
+            delay = self.update_period - (time.perf_counter() - start_time)
+            if delay < 0:
+                logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it")
+            self.trigger.wait(max(delay, 0))
+            self.trigger.clear()
+
+    def announce(self, state: ServerState) -> None:
+        self.server_info.state = state
+        self.trigger.set()
+        if state == ServerState.OFFLINE:
+            self.join()
+
+    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
+        [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
+        if module_info is None:
+            return
+
+        next_servers = list(module_info.servers)
+        if len(next_servers) > self.max_pinged:
+            next_servers = random.sample(next_servers, self.max_pinged)
+        self.ping_aggregator.ping(next_servers)
+
 
 class RuntimeWithDeduplicatedPools(Runtime):
     """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""

+ 60 - 0
src/petals/utils/ping.py

@@ -0,0 +1,60 @@
+import asyncio
+import math
+import time
+from functools import partial
+from typing import Dict, Sequence
+
+import hivemind
+from hivemind.proto import dht_pb2
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+async def ping(
+    peer_id: hivemind.PeerID,
+    _dht: hivemind.DHT,
+    node: hivemind.dht.DHTNode,
+    *,
+    wait_timeout: float = 1,
+) -> float:
+    try:
+        ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
+        start_time = time.perf_counter()
+        await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
+        return time.perf_counter() - start_time
+    except Exception:
+        logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
+        return math.inf
+
+
+async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
+    rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])
+    return dict(zip(peer_ids, rpc_infos))
+
+
+class PingAggregator:
+    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600):
+        self.dht = dht
+        self.ema_alpha = ema_alpha
+        self.expiration = expiration
+        self.ping_emas = hivemind.TimedStorage()
+
+    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs):
+        current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
+        logger.debug(f"Current RTTs: {current_rtts}")
+
+        expiration = hivemind.get_dht_time() + self.expiration
+        for peer_id, rtt in current_rtts.items():
+            prev_rtt = self.ping_emas.get(peer_id)
+            if prev_rtt is not None and prev_rtt.value != math.inf:
+                rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
+            self.ping_emas.store(peer_id, rtt, expiration)
+
+    def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
+        with self.ping_emas.freeze():
+            smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
+        logger.debug(f"Smothed RTTs: {smoothed_rtts}")
+
+        fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
+        return dict(fastest_rtts)
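
Finally, a minimal usage sketch of PingAggregator, mirroring how ModuleAnnouncerThread uses it in server.py. The peer list is a placeholder you would obtain elsewhere, e.g. from get_remote_module_infos:

    # Minimal sketch: one ping round against a few peers, then keep the fastest ones.
    # Assumes a running hivemind DHT; `peer_ids` is a placeholder list of hivemind.PeerID objects.
    import hivemind

    from petals.utils.ping import PingAggregator

    dht = hivemind.DHT(start=True)  # in practice, pass initial_peers=... to join the target swarm
    aggregator = PingAggregator(dht)

    peer_ids = []  # fill with hivemind.PeerID objects, e.g. module_info.servers from the DHT
    aggregator.ping(peer_ids)         # parallel RPC pings, folded into per-peer EMAs
    fastest = aggregator.fastest(10)  # {PeerID: smoothed RTT in seconds}; unreachable peers get math.inf
    next_pings = {peer_id.to_base58(): rtt for peer_id, rtt in fastest.items()}  # as reported in ServerInfo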