@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 import itertools
 import logging
 import random
@@ -13,7 +14,6 @@ import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 
@@ -26,6 +26,33 @@ from petals.server.handler import TransformerConnectionHandler
 logger = get_logger(__name__)
 
 
+@dataclasses.dataclass
+class SequenceManagerConfig:
+    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number of retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+
+
+@dataclasses.dataclass
+class SequenceManagerState:
+    p2p: Optional[P2P] = None
+    sequence_info: Optional[RemoteSequenceInfo] = None
+    rpc_info: Optional[dict] = None
+    banned_peers: Optional[Blacklist] = None
+
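+    # Slicing a state (below) narrows sequence_info to the chosen blocks, while p2p, rpc_info,
+    # and banned_peers are shared with the original state (dataclasses.replace copies fields as-is)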
+    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
+        return dataclasses.replace(self, sequence_info=self.sequence_info[ix])
+
+    def __len__(self) -> int:
+        return len(self.sequence_info)
+
+
 class RemoteSequenceManager:
     """
     Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
@@ -34,67 +61,56 @@ class RemoteSequenceManager:
     Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
     To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
 
-    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
-    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
-    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
-    :param update_period: by default, refresh DHT information once in this many seconds
-    :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
-    :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
-    :param max_backoff: limit maximal sleep time between retries to this value
-    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
-    :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
-    :param allowed_servers: if defined, send requests only to these servers
-    :param start: start the background thread (see the note below). If false, you will need to start it manually.
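+    :param config: a SequenceManagerConfig with timeouts, retry and server-selection policy for this client
+    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
+    :param dht: a running hivemind.DHT instance; if not specified, a new one is created from the config
+    :param state: optionally, a pre-initialized SequenceManagerState (used when slicing an existing manager)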
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
         running redundant sequence managers for the same set of layers.
-
     """
 
     def __init__(
         self,
-        dht: DHT,
+        config: SequenceManagerConfig,
         block_uids: Sequence[ModuleUID],
-        p2p: P2P,
-        update_period: float = 30,
-        request_timeout: float = 30,
-        max_retries: Optional[int] = None,
-        min_backoff: float = 1,
-        max_backoff: float = 15 * 60,
-        ban_timeout: float = 15,
-        sequence_info: Optional[RemoteSequenceInfo] = None,
-        rpc_info: Optional[dict] = None,
-        allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
-        banned_peers: Optional[Blacklist] = None,
-        # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
+        *,
+        dht: Optional[DHT] = None,
+        state: Optional[SequenceManagerState] = None,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
-        self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.max_retries = request_timeout, max_retries
-        self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
+
+        self.config = config
+        if state is None:
+            state = SequenceManagerState()
+        self.state = state
+
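+        # NOTE: initial_peers, n_layer, and daemon_startup_timeout are not declared on
+        # SequenceManagerConfig above; they are assumed to come from a model-specific config subclass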
+        if dht is None:
+            dht = DHT(
+                initial_peers=config.initial_peers,
+                client_mode=True,
+                num_workers=config.n_layer,
+                startup_timeout=config.daemon_startup_timeout,
+                start=True,
+            )
+        assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
+        self.dht = dht
+
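+        # reuse the libp2p daemon already running inside the DHT instead of spawning a second one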
+        if state.p2p is None:
+            state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+
         self.lock_changes = threading.Lock()
-        self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
+        self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
-        self._rpc_info = rpc_info
 
-        if allowed_servers is not None:
-            allowed_servers = {
-                PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
-            }
-        self.allowed_servers = allowed_servers
-        self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
-
-        if sequence_info is None:
-            self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
+        if state.banned_peers is None:
+            state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
+        if state.sequence_info is None:
+            state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
 
+        if state.sequence_info.last_updated_time is None:
             # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
             # in the first _update() instead of the latest ones. This makes the first .update() faster.
             petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
             self._need_latest_infos = False
         else:
-            self.sequence_info = sequence_info
-            assert block_uids == sequence_info.block_uids
+            assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
@@ -118,7 +134,7 @@ class RemoteSequenceManager:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
-            candidate_spans = self.sequence_info.spans_containing_block[current_index]
+            candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
             if mode == "random":
@@ -143,86 +159,62 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(
-            self.dht,
-            self.block_uids[ix],
-            self.p2p,
-            update_period=self._thread.update_period,
-            request_timeout=self.request_timeout,
-            ban_timeout=self.ban_timeout,
-            min_backoff=self.min_backoff,
-            max_backoff=self.max_backoff,
-            sequence_info=self.sequence_info[ix],
-            rpc_info=self._rpc_info,
-            allowed_servers=self.allowed_servers,
-            banned_peers=self.banned_peers,
-        )
+        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
 
     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
-        self.ready.clear()  # TODO this should be a separate event
+        self.ready.clear()
         self._thread.trigger.set()
         if wait:
             self.ready.wait()
 
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
-        for attempt_no in itertools.count():
-            try:
-                new_block_infos = petals.dht_utils.get_remote_module_infos(
-                    self.dht, self.block_uids, latest=self._need_latest_infos
-                )
-                self._need_latest_infos = True  # All future _update() should use latest infos
-
-                for block_info in new_block_infos:
-                    if not block_info:
-                        continue
-
-                    # Apply whitelist, if defined
-                    if self.allowed_servers is not None:
-                        block_info.servers = {
-                            peer_id: server_info
-                            for peer_id, server_info in block_info.servers.items()
-                            if peer_id in self.allowed_servers
-                        }
-
-                    # Remove temporarily banned peers, unless there are no peers left
-                    valid_servers = {
-                        peer_id: server_info
-                        for peer_id, server_info in block_info.servers.items()
-                        if peer_id not in self.banned_peers
-                    }
-                    if len(valid_servers) < len(block_info.servers):
-                        if valid_servers:
-                            logger.debug(
-                                f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
-                            )
-                            block_info.servers = valid_servers
-                        else:
-                            # If we blacklisted all servers, the error may actually be client-caused
-                            logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
-
-                with self.lock_changes:
-                    self.sequence_info.update_(new_block_infos)
-                missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
-                if missing_blocks:
-                    raise MissingBlocksError(missing_blocks)
-                self.ready.set()  # if there is an active server for every block, we may begin running
-                break
+        new_block_infos = petals.dht_utils.get_remote_module_infos(
+            self.dht, self.block_uids, latest=self._need_latest_infos
+        )
+        self._need_latest_infos = True  # All future _update() should use latest infos
+
+        for block_info in new_block_infos:
+            if not block_info:
+                continue
+
+            # Apply whitelist, if defined
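+            # (allowed_servers may contain PeerID objects or their base58 strings, so both forms are checked)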
+            if self.config.allowed_servers is not None:
+                block_info.servers = {
+                    peer_id: server_info
+                    for peer_id, server_info in block_info.servers.items()
+                    if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
+                }
+
+            # Remove temporarily banned peers, unless there are no peers left
+            valid_servers = {
+                peer_id: server_info
+                for peer_id, server_info in block_info.servers.items()
+                if peer_id not in self.state.banned_peers
+            }
+            if len(valid_servers) < len(block_info.servers):
+                if valid_servers:
+                    logger.debug(
+                        f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
+                    )
+                    block_info.servers = valid_servers
+                else:
+                    # If we blacklisted all servers, the error may actually be client-caused
+                    logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
 
-            except Exception as e:
-                delay = self.get_retry_delay(attempt_no)
-                logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
-                maybe_log_traceback(e)
-                time.sleep(delay)
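+        # NOTE: _update no longer retries on failure; retry and backoff logic lives with the callers
+        # (see the rpc_info property below)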
+        with self.lock_changes:
+            self.state.sequence_info.update_(new_block_infos)
+        self.ready.set()
 
-    def on_request_failure(self, peer_id: PeerID):
+    def on_request_failure(self, peer_id: Optional[PeerID]):
         """remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
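+        # peer_id may be None when a request failed before any server was chosen (see rpc_info)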
+        if peer_id is not None:
+            logger.debug(f"Peer {peer_id} did not respond, banning it temporarily")
+            self.state.banned_peers.register_failure(peer_id)
         with self.lock_changes:
             should_update = False
-            for info in self.sequence_info.block_infos:
+            for info in self.state.sequence_info.block_infos:
                 info.servers.pop(peer_id, None)
                 if not info.servers:
                     should_update = True
@@ -232,7 +224,7 @@ class RemoteSequenceManager:
 
     def on_request_success(self, peer_id: PeerID):
         """if peer has a failure streak, clear that streak"""
-        self.banned_peers.register_success(peer_id)
+        self.state.banned_peers.register_success(peer_id)
 
     def __len__(self):
         return len(self.block_uids)
@@ -247,57 +239,58 @@ class RemoteSequenceManager:
 
     @property
    def block_uids(self):
-        return self.sequence_info.block_uids
+        return self.state.sequence_info.block_uids
 
     @property
     def rpc_info(self):
         """Return the rpc_info queried from one of the servers that hold the first block"""
-        if self._rpc_info is None:
-            with self._thread_start_lock:
-                if not self.is_alive():
-                    self._thread.start()
-
-            for attempt_no in itertools.count():
-                peer_id = None
-                try:
-                    if not self.ready.is_set():
-                        self.update(wait=True)
-
-                    active_servers = [
-                        peer_id
-                        for peer_id, server in self.sequence_info.block_infos[0].servers.items()
-                        if server.state == ServerState.ONLINE
-                    ]
-                    if not active_servers:
-                        raise MissingBlocksError(0)
-                    peer_id = random.choice(active_servers)
-
-                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
-                    outputs = RemoteExpertWorker.run_coroutine(
-                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
-                    )
-                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
-                    self.on_request_success(peer_id)
-                    break
-                except Exception as e:
-                    if peer_id is not None and not isinstance(e, P2PHandlerError):
-                        self.on_request_failure(peer_id)
-                    if attempt_no + 1 == self.max_retries:
-                        raise
-                    delay = self.get_retry_delay(attempt_no)
-                    logger.warning(
-                        f"Caught exception when gathering information from peer {peer_id} "
-                        f"(retry in {delay:.0f} sec): {repr(e)}"
-                    )
-                    maybe_log_traceback(e)
-                    time.sleep(delay)
+        if self.state.rpc_info is not None:
+            return self.state.rpc_info
+
+        with self._thread_start_lock:
+            if not self.is_alive():
+                self._thread.start()
+
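+        # retry indefinitely by default (config.max_retries is None), with exponential backoff
+        # between attempts and temporary bans for unresponsive peers (via on_request_failure)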
+        for attempt_no in itertools.count():
+            peer_id = None
+            try:
+                if not self.ready.is_set():
+                    self.update(wait=True)
+
+                active_servers = [
+                    peer_id
+                    for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()
+                    if server.state == ServerState.ONLINE
+                ]
+                if not active_servers:
+                    raise MissingBlocksError(0)
+                peer_id = random.choice(active_servers)
+
+                stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)
+                outputs = RemoteExpertWorker.run_coroutine(
+                    stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)
+                )
+                self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                self.on_request_success(peer_id)
+                break
+            except Exception as e:
+                self.on_request_failure(peer_id)
+                if attempt_no + 1 == self.config.max_retries:
+                    raise
+                delay = self.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when gathering information from peer {peer_id} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                maybe_log_traceback(e)
+                time.sleep(delay)
 
-        return self._rpc_info
+        return self.state.rpc_info
 
     def get_retry_delay(self, attempt_no: int) -> float:
         if attempt_no == 0:
             return 0
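+        # exponential backoff: sleep min_backoff * 2 ** (attempt_no - 1) seconds, capped at max_backoff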
-        return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff)
+        return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
     def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
         """