@@ -8,7 +8,8 @@ import time
 from typing import Any, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
-from hivemind import DHT, P2P, MSGPackSerializer
+from hivemind import DHT, P2P, MSGPackSerializer, PeerID
+from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -39,6 +40,7 @@ class RemoteSequenceManager:
     :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
     :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
     :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
+    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
     :param start: start the background thread (see the note below). If false, you will need to start it manually.
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
         running redundant sequence managers for the same set of layers.
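
The two timing knobs documented above interact: min_backoff shapes retry delays within a single call, while ban_timeout sets the base duration for which a failing peer is excluded from routing. A minimal sketch of the arithmetic, assuming retry delays follow the documented formula and bans double per consecutive failure (backoff_rate=2.0, as constructed in the next hunk):

    # Illustrative only -- not part of the diff.
    min_backoff, ban_timeout = 1.0, 60.0
    for num_failures in range(1, 5):
        retry_delay = min_backoff * 2 ** (num_failures - 1)      # 1, 2, 4, 8 seconds
        ban_duration = ban_timeout * 2.0 ** (num_failures - 1)   # 60, 120, 240, 480 seconds
        print(f"failure #{num_failures}: retry in {retry_delay:.0f}s, ban for {ban_duration:.0f}s")
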
@@ -53,18 +55,21 @@ class RemoteSequenceManager:
         update_period: float = 30,
         request_timeout: float = 30,
         min_backoff: float = 1,
+        ban_timeout: float = 60,
         sequence_info: Optional[RemoteSequenceInfo] = None,
         rpc_info: Optional[dict] = None,
+        banned_peers: Optional[Blacklist] = None,
         *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
         start: bool,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.min_backoff = request_timeout, min_backoff
+        self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff
         self.lock_changes = threading.Lock()
         self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
         self.policy = NoSpendingPolicy()
         self._rpc_info = rpc_info
+        self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
 
         if sequence_info is None:
             self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
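
The banned_peers blacklist is deliberately shared: __getitem__ (further below) passes the same object to every sliced sub-manager, so a ban recorded through one slice applies to all of them. A minimal sketch of the Blacklist behavior this code relies on, assuming hivemind's register_failure / register_success / __contains__ interface and a placeholder peer id:

    # Illustrative only -- not part of the diff.
    from hivemind import PeerID
    from hivemind.dht.node import Blacklist

    banned_peers = Blacklist(base_time=60, backoff_rate=2.0)
    peer = PeerID(b"placeholder-peer-id")  # hypothetical peer id
    banned_peers.register_failure(peer)    # banned; each further failure doubles the ban
    assert peer in banned_peers
    banned_peers.register_success(peer)    # one successful response clears the streak
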
@@ -97,7 +102,7 @@ class RemoteSequenceManager:
             logger.error("Using a sequence manager that is not running: it has either crashed or never started")
         if not self.ready.is_set():
             logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
-            self.ready.wait()
+            self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)
 
         end_index = end_index if end_index is not None else len(self)
         span_sequence = []
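
Replacing self.ready.wait() with self.update(wait=True) removes a possible stall: waiting on the ready event only helps if a refresh is already in flight, whereas update(wait=True) also triggers one when nothing is updating. A sketch of the contract this assumes (update() itself is outside the shown hunks, so the body below is hypothetical):

    # Illustrative only -- not part of the diff.
    def update(self, *, wait: bool):
        self.ready.clear()           # mark current routes as stale
        self._thread.trigger.set()   # wake the background update thread
        if wait:
            self.ready.wait()        # block until a fresh route is available
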
@@ -123,9 +128,11 @@ class RemoteSequenceManager:
             self.p2p,
             update_period=self._thread.update_period,
             request_timeout=self.request_timeout,
+            ban_timeout=self.ban_timeout,
             min_backoff=self.min_backoff,
             sequence_info=self.sequence_info[ix],
             rpc_info=self._rpc_info,
+            banned_peers=self.banned_peers,
             start=True,
         )
 
@@ -143,6 +150,14 @@ class RemoteSequenceManager:
                 new_block_infos = petals.dht_utils.get_remote_module_infos(
                     self.dht, self.block_uids, expiration_time=float("inf")
                 )
+                for block_info in new_block_infos:
+                    if not block_info:
+                        continue
+                    for peer_id in tuple(block_info.servers.keys()):
+                        if peer_id in self.banned_peers:
+                            logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}")
+                            block_info.servers.pop(peer_id)
+
                 with self.lock_changes:
                     self.sequence_info.update_(new_block_infos)
                 missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
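
Note the tuple(block_info.servers.keys()) snapshot: the loop pops entries from the very dict it is scanning, and iterating a live dict view while the dict changes size raises RuntimeError. A minimal reproduction of the pattern:

    # Illustrative only -- not part of the diff.
    servers = {"peer_a": 1, "peer_b": 2}  # stand-in for block_info.servers
    for key in tuple(servers.keys()):     # frozen snapshot, safe to mutate below
        if key == "peer_b":
            servers.pop(key)
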
@@ -158,6 +173,24 @@ class RemoteSequenceManager:
                 logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 time.sleep(delay)
 
+    def on_request_failure(self, peer_id: PeerID):
+        """remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
+        logger.info(f"Peer {peer_id} did not respond, banning it temporarily")
+        self.banned_peers.register_failure(peer_id)
+        with self.lock_changes:
+            should_update = False
+            for info in self.sequence_info.block_infos:
+                info.servers.pop(peer_id, None)
+                if not info.servers:
+                    should_update = True
+            if should_update:
+                self.ready.clear()
+                self.update(wait=False)
+
+    def on_request_success(self, peer_id: PeerID):
+        """if peer has a failure streak, clear that streak"""
+        self.banned_peers.register_success(peer_id)
+
     def __len__(self):
         return len(self.block_uids)
 
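
Every remote call is expected to report its outcome through these two hooks, as the rpc_info property below does. A sketch of the intended call-site pattern (do_remote_call is a hypothetical stand-in for an actual RPC):

    # Illustrative only -- not part of the diff.
    try:
        outputs = do_remote_call(peer_id)
        sequence_manager.on_request_success(peer_id)
    except Exception:
        sequence_manager.on_request_failure(peer_id)
        raise
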
@@ -178,16 +211,21 @@ class RemoteSequenceManager:
         """Return the rpc_info queried from one of the servers that hold the first block"""
         if self._rpc_info is None:
             for attempt_no in itertools.count():
+                peer_id = None
                 try:
-                    self._update()
+                    if not self.ready.is_set():
+                        self.update(wait=True)
                     peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
                     stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
                     outputs = RemoteExpertWorker.run_coroutine(
                         stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
                     )
                     self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                    self.on_request_success(peer_id)
                     break
                 except Exception as e:
+                    if peer_id is not None:
+                        self.on_request_failure(peer_id)
                     delay = self.get_retry_delay(attempt_no)
                     logger.warning(
                         f"Caught exception when gathering information from peer {peer_id} "
|