
Make a server ping next servers (#356)

This PR makes a server ping potential next servers in a chain and report the RTTs to DHT. This will be used for shortest-path routing.
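For context, a rough sketch of how a client could consume the reported RTTs when choosing the next hop; the pick_next_server helper and the candidate_peer_ids argument are illustrative assumptions, only the next_pings field is actually introduced by this PR:

import math
from typing import Optional, Sequence

from petals.data_structures import ServerInfo


def pick_next_server(current: ServerInfo, candidate_peer_ids: Sequence[str]) -> Optional[str]:
    """Greedily choose the candidate (base58 peer ID) with the lowest RTT reported in next_pings."""
    pings = current.next_pings or {}
    # Candidates the current server has not pinged are treated as unreachable (infinite RTT)
    rtts = {peer_id: pings.get(peer_id, math.inf) for peer_id in candidate_peer_ids}
    best = min(rtts, key=rtts.get, default=None)
    return best if best is not None and rtts[best] < math.inf else None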
Alexander Borzunov 2 years ago
parent
commit
81c4a45ca2
4 changed files with 120 additions and 53 deletions
  1. src/petals/cli/run_server.py (+1 -1)
  2. src/petals/data_structures.py (+1 -0)
  3. src/petals/server/server.py (+58 -52)
  4. src/petals/utils/ping.py (+60 -0)

+ 1 - 1
src/petals/cli/run_server.py

@@ -98,7 +98,7 @@ def main():
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
                              'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
-    parser.add_argument('--update_period', type=float, required=False, default=150,
+    parser.add_argument('--update_period', type=float, required=False, default=60,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')

+ 1 - 0
src/petals/data_structures.py

@@ -30,6 +30,7 @@ class ServerInfo:
     quant_type: Optional[str] = None
     using_relay: Optional[bool] = None
     cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
+    next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None

     def to_tuple(self) -> Tuple[int, float, dict]:
         extra_info = dataclasses.asdict(self)

+ 58 - 52
src/petals/server/server.py

@@ -8,6 +8,7 @@ import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union

+import hivemind
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -30,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.ping import PingAggregator
 from petals.utils.version import get_compatible_model_repo

 logger = get_logger(__name__)
@@ -64,7 +66,7 @@ class Server:
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
-        update_period: float = 150,
+        update_period: float = 60,
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
@@ -220,7 +222,7 @@ class Server:
             throughput=throughput,
             adapters=tuple(adapters),
             version=petals.__version__,
-            torch_dtype=str(torch_dtype).lstrip("torch."),
+            torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
         )
@@ -413,8 +415,8 @@ class ModuleContainer(threading.Thread):
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
         memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)

-        server_info.state = ServerState.JOINING
-        joining_announcer = ModuleAnnouncerThread(
+        assert server_info.state == ServerState.JOINING
+        dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
             server_info,
@@ -424,7 +426,7 @@ class ModuleContainer(threading.Thread):
             expiration=expiration,
             daemon=True,
         )
-        joining_announcer.start()
+        dht_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")

         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
@@ -476,6 +478,8 @@ class ModuleContainer(threading.Thread):
                     max_batch_size=max_batch_size,
                 )

+            merge_inference_pools_inplace(blocks)
+
             if should_validate_reachability:
                 validate_reachability(dht.peer_id)
         except:
@@ -483,29 +487,15 @@ class ModuleContainer(threading.Thread):
             for backend in blocks.values():
                 backend.shutdown()

-            joining_announcer.stop.set()
-            joining_announcer.join()
-            server_info.state = ServerState.OFFLINE
-            declare_active_modules(
-                dht,
-                module_uids,
-                server_info,
-                expiration_time=get_dht_time() + expiration,
-            )
+            dht_announcer.announce(ServerState.OFFLINE)
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
-        else:
-            joining_announcer.stop.set()
-            joining_announcer.join()
-
-        merge_inference_pools_inplace(blocks)

         return cls(
             dht,
             dht_prefix,
             blocks,
-            block_config=block_config,
-            memory_cache=memory_cache,
+            dht_announcer=dht_announcer,
             server_info=server_info,
             update_period=update_period,
             expiration=expiration,
@@ -518,10 +508,9 @@ class ModuleContainer(threading.Thread):
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
-        block_config: PretrainedConfig,
-        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
+        dht_announcer: ModuleAnnouncerThread,
         server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
@@ -558,17 +547,8 @@ class ModuleContainer(threading.Thread):
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.

-        self.server_info.state = ServerState.ONLINE
-        self.online_announcer = ModuleAnnouncerThread(
-            list(self.module_backends.keys()),
-            dht,
-            self.server_info,
-            block_config=block_config,
-            memory_cache=memory_cache,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
+        dht_announcer.announce(ServerState.ONLINE)
+        self.dht_announcer = dht_announcer

         if start:
             self.run_in_background(await_ready=True)
@@ -578,11 +558,6 @@ class ModuleContainer(threading.Thread):
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
-
-        self.online_announcer.start()
-
         for handler in self.conn_handlers:
             handler.run_in_background()
 
 
@@ -621,16 +596,7 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        self.online_announcer.stop.set()
-        self.online_announcer.join()
-
-        self.server_info.state = ServerState.OFFLINE
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            self.server_info,
-            expiration_time=get_dht_time() + self.expiration,
-        )
+        self.dht_announcer.announce(ServerState.OFFLINE)
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

         self.ready.clear()
@@ -666,8 +632,10 @@ class ModuleAnnouncerThread(threading.Thread):
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
-        update_period: float = 30,
+        update_period: float,
         expiration: float,
+        max_pinged: int = 5,
+        max_reported: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -678,20 +646,58 @@ class ModuleAnnouncerThread(threading.Thread):
         self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
-        self.stop = threading.Event()
+        self.trigger = threading.Event()
+
+        self.max_pinged, self.max_reported = max_pinged, max_reported
+        last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
+        dht_prefix, block_index = last_uid.split(UID_DELIMITER)
+        self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
+        self.ping_aggregator = PingAggregator(self.dht)

     def run(self) -> None:
         while True:
+            start_time = time.perf_counter()
+
             self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
+            if self.server_info.state != ServerState.OFFLINE:
+                self._ping_next_servers()
+                self.server_info.next_pings = {
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
+                }
+            else:
+                self.server_info.next_pings = None  # No need to ping if we're disconnecting
+
             declare_active_modules(
                 self.dht,
                 self.module_uids,
                 self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
             )
-            if self.stop.wait(self.update_period):
+            if self.server_info.state == ServerState.OFFLINE:
                 break
 
 
+            delay = self.update_period - (time.perf_counter() - start_time)
+            if delay < 0:
+                logger.warning("Declaring blocks to DHT takes more than --update_period, consider increasing it")
+            self.trigger.wait(max(delay, 0))
+            self.trigger.clear()
+
+    def announce(self, state: ServerState) -> None:
+        self.server_info.state = state
+        self.trigger.set()
+        if state == ServerState.OFFLINE:
+            self.join()
+
+    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
+        [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
+        if module_info is None:
+            return
+
+        next_servers = list(module_info.servers)
+        if len(next_servers) > self.max_pinged:
+            next_servers = random.sample(next_servers, self.max_pinged)
+        self.ping_aggregator.ping(next_servers)
+
 
 
 class RuntimeWithDeduplicatedPools(Runtime):
     """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""

+ 60 - 0
src/petals/utils/ping.py

@@ -0,0 +1,60 @@
+import asyncio
+import math
+import time
+from functools import partial
+from typing import Dict, Sequence
+
+import hivemind
+from hivemind.proto import dht_pb2
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+async def ping(
+    peer_id: hivemind.PeerID,
+    _dht: hivemind.DHT,
+    node: hivemind.dht.DHTNode,
+    *,
+    wait_timeout: float = 1,
+) -> float:
+    try:
+        ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
+        start_time = time.perf_counter()
+        await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
+        return time.perf_counter() - start_time
+    except Exception:
+        logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
+        return math.inf
+
+
+async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
+    rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])
+    return dict(zip(peer_ids, rpc_infos))
+
+
+class PingAggregator:
+    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600):
+        self.dht = dht
+        self.ema_alpha = ema_alpha
+        self.expiration = expiration
+        self.ping_emas = hivemind.TimedStorage()
+
+    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs):
+        current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
+        logger.debug(f"Current RTTs: {current_rtts}")
+
+        expiration = hivemind.get_dht_time() + self.expiration
+        for peer_id, rtt in current_rtts.items():
+            prev_rtt = self.ping_emas.get(peer_id)
+            if prev_rtt is not None and prev_rtt.value != math.inf:
+                rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
+            self.ping_emas.store(peer_id, rtt, expiration)
+
+    def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
+        with self.ping_emas.freeze():
+            smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
+        logger.debug(f"Smoothed RTTs: {smoothed_rtts}")
+
+        fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
+        return dict(fastest_rtts)
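
A minimal usage sketch of the new module, assuming a running hivemind DHT node; in the server itself, the peer IDs come from get_remote_module_infos() inside ModuleAnnouncerThread above, and the measure_fastest_peers helper here is a hypothetical illustration rather than part of this change:

from typing import Dict, Sequence

import hivemind

from petals.utils.ping import PingAggregator


def measure_fastest_peers(
    dht: hivemind.DHT, peer_ids: Sequence[hivemind.PeerID], top_k: int = 10
) -> Dict[hivemind.PeerID, float]:
    aggregator = PingAggregator(dht)
    aggregator.ping(peer_ids)  # sends rpc_ping to all peers in parallel, stores EMA-smoothed RTTs
    return aggregator.fastest(top_k)  # up to top_k peers with the lowest smoothed RTT, in seconds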