|
@@ -10,7 +10,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union
|
|
|
from weakref import WeakMethod
|
|
|
|
|
|
import numpy as np
|
|
|
-from hivemind import DHT, P2P, MSGPackSerializer, PeerID
|
|
|
+from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
|
|
|
from hivemind.dht.node import Blacklist
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.p2p import P2PHandlerError
|
|
@@ -66,8 +66,7 @@ class RemoteSequenceManager:
|
|
|
rpc_info: Optional[dict] = None,
|
|
|
allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
|
|
|
banned_peers: Optional[Blacklist] = None,
|
|
|
- *, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
|
|
|
- start: bool,
|
|
|
+ # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
|
|
|
):
|
|
|
assert len(block_uids) > 0, "Sequences must contain at least one block"
|
|
|
self.dht, self.p2p = dht, p2p
|
|
@@ -75,6 +74,7 @@ class RemoteSequenceManager:
|
|
|
self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
|
|
|
self.lock_changes = threading.Lock()
|
|
|
self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
|
|
|
+ self._thread_start_lock = threading.Lock()
|
|
|
self.policy = NoSpendingPolicy()
|
|
|
self._rpc_info = rpc_info
|
|
|
|
|
@@ -87,23 +87,16 @@ class RemoteSequenceManager:
|
|
|
|
|
|
if sequence_info is None:
|
|
|
self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
|
|
|
- self.update(wait=False)
|
|
|
+
|
|
|
+ # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
|
|
|
+ # in the first _update() instead of the latest ones. This makes the first .update() faster.
|
|
|
+ petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
|
|
|
+ self._need_latest_infos = False
|
|
|
else:
|
|
|
self.sequence_info = sequence_info
|
|
|
assert block_uids == sequence_info.block_uids
|
|
|
self._thread.ready.set() # no need to await the first dht fetch
|
|
|
-
|
|
|
- if start:
|
|
|
- self.run_in_background()
|
|
|
-
|
|
|
- def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
|
|
|
- """
|
|
|
- Starts the updater thread in a background. if await_ready, this method will wait until sequence manager
|
|
|
- is ready to process incoming requests or for :timeout: seconds max.
|
|
|
- """
|
|
|
- self._thread.start()
|
|
|
- if await_ready:
|
|
|
- self._thread.ready.wait(timeout)
|
|
|
+ self._need_latest_infos = True
|
|
|
|
|
|
def make_sequence(
|
|
|
self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
|
|
@@ -115,10 +108,10 @@ class RemoteSequenceManager:
|
|
|
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
|
|
|
:param mode: either random or fastest
|
|
|
"""
|
|
|
- if not self.is_alive():
|
|
|
- logger.error("Using a sequence manager that is not running: it has either crashed or never started")
|
|
|
+ with self._thread_start_lock:
|
|
|
+ if not self.is_alive():
|
|
|
+ self._thread.start()
|
|
|
if not self.ready.is_set():
|
|
|
- logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
|
|
|
self.update(wait=True) # this will await an existing update or trigger a new one (if not updating)
|
|
|
|
|
|
end_index = end_index if end_index is not None else len(self)
|
|
@@ -163,7 +156,6 @@ class RemoteSequenceManager:
|
|
|
rpc_info=self._rpc_info,
|
|
|
allowed_servers=self.allowed_servers,
|
|
|
banned_peers=self.banned_peers,
|
|
|
- start=True,
|
|
|
)
|
|
|
|
|
|
def update(self, *, wait: bool):
|
|
@@ -178,8 +170,10 @@ class RemoteSequenceManager:
|
|
|
for attempt_no in itertools.count():
|
|
|
try:
|
|
|
new_block_infos = petals.dht_utils.get_remote_module_infos(
|
|
|
- self.dht, self.block_uids, expiration_time=float("inf")
|
|
|
+ self.dht, self.block_uids, latest=self._need_latest_infos
|
|
|
)
|
|
|
+ self._need_latest_infos = True # All future _update() should use latest infos
|
|
|
+
|
|
|
for block_info in new_block_infos:
|
|
|
if not block_info:
|
|
|
continue
|
|
@@ -259,6 +253,10 @@ class RemoteSequenceManager:
|
|
|
def rpc_info(self):
|
|
|
"""Return the rpc_info queried from one of the servers that hold the first block"""
|
|
|
if self._rpc_info is None:
|
|
|
+ with self._thread_start_lock:
|
|
|
+ if not self.is_alive():
|
|
|
+ self._thread.start()
|
|
|
+
|
|
|
for attempt_no in itertools.count():
|
|
|
peer_id = None
|
|
|
try:
|
|
@@ -320,18 +318,11 @@ class _SequenceManagerUpdateThread(threading.Thread):
|
|
|
self.ref_update_manager = ref_update_manager
|
|
|
self.ready = threading.Event()
|
|
|
self.trigger = threading.Event()
|
|
|
- self.last_update_time = -float("inf")
|
|
|
self.update_period = update_period
|
|
|
self.should_shutdown = False
|
|
|
|
|
|
def run(self) -> None:
|
|
|
while not self.should_shutdown:
|
|
|
- self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time)))
|
|
|
-
|
|
|
- if self.should_shutdown:
|
|
|
- logger.debug(f"{self.__class__.__name__} is shutting down")
|
|
|
- break
|
|
|
-
|
|
|
update_manager = self.ref_update_manager()
|
|
|
if update_manager is None:
|
|
|
logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
|
|
@@ -345,16 +336,18 @@ class _SequenceManagerUpdateThread(threading.Thread):
|
|
|
finally:
|
|
|
del update_manager
|
|
|
|
|
|
+ self.trigger.wait(self.update_period)
|
|
|
+
|
|
|
logger.debug(f"{self.__class__.__name__} thread exited")
|
|
|
|
|
|
def shutdown(self, timeout: Optional[float] = None):
|
|
|
self.should_shutdown = True
|
|
|
self.trigger.set()
|
|
|
- self.join(timeout)
|
|
|
+ if self.is_alive():
|
|
|
+ self.join(timeout)
|
|
|
|
|
|
def __del__(self):
|
|
|
- if self.is_alive():
|
|
|
- self.shutdown()
|
|
|
+ self.shutdown()
|
|
|
|
|
|
|
|
|
def maybe_log_traceback(exc: Exception):
|