
Start SequenceManager's thread only after first .make_sequence() (#301)

**Why?**

- We'd like to avoid excess threads for the original sequence manager in case we only use its slices (e.g., when we add adapters or need only a subset of model blocks).

- If we create a sequence manager just before a fork (e.g., in a web app backend or a multi-process benchmark), we'd like to avoid excess threads in the original process and start the updater thread only in the child processes that actually call `.make_sequence()` (see the sketch below).
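
As a rough illustration of the second point, here is a minimal sketch of that fork pattern (the `INITIAL_PEERS`, `DHT_PREFIX`, and block range are hypothetical placeholders, and a `fork` start method on Linux is assumed): the manager is built in the parent, but its updater thread only appears in the children, on their first `.make_sequence()` call.

```python
import multiprocessing as mp

import hivemind
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker

import petals

# Hypothetical bootstrap parameters, used only to keep the sketch self-contained.
INITIAL_PEERS = ["/ip4/203.0.113.1/tcp/31337/p2p/QmExamplePeerID"]
DHT_PREFIX = "bigscience/bloom-petals"

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
block_uids = [f"{DHT_PREFIX}.{i}" for i in range(4)]  # "." stands in for the UID delimiter

# After this change, constructing the manager no longer spawns the updater thread,
# so the parent process stays free of it until a route is actually requested.
manager = petals.client.RemoteSequenceManager(dht, block_uids, p2p)


def worker() -> None:
    # The updater thread starts lazily here, inside the child process,
    # on the first .make_sequence() call (see the sequence_manager.py diff below).
    spans = manager.make_sequence(mode="random")
    print([span.peer_id for span in spans])


if __name__ == "__main__":
    mp.set_start_method("fork")  # children inherit `manager` without pickling
    processes = [mp.Process(target=worker) for _ in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```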
Alexander Borzunov committed 2 years ago
parent commit 21c3526ec1

+ 0 - 1
src/petals/client/remote_sequential.py

@@ -48,7 +48,6 @@ class RemoteSequential(nn.Module):
                 request_timeout=config.request_timeout,
                 max_retries=config.max_retries,
                 allowed_servers=config.allowed_servers,
-                start=True,
                 **kwargs,
             )
             self.is_subsequence = False

+ 24 - 31
src/petals/client/routing/sequence_manager.py

@@ -10,7 +10,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
 import numpy as np
-from hivemind import DHT, P2P, MSGPackSerializer, PeerID
+from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2PHandlerError
@@ -66,8 +66,7 @@ class RemoteSequenceManager:
         rpc_info: Optional[dict] = None,
         allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
         banned_peers: Optional[Blacklist] = None,
-        *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
-        start: bool,
+        # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
@@ -75,6 +74,7 @@ class RemoteSequenceManager:
         self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
         self.lock_changes = threading.Lock()
         self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
+        self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
         self._rpc_info = rpc_info
 
@@ -87,23 +87,16 @@ class RemoteSequenceManager:
 
         if sequence_info is None:
             self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
-            self.update(wait=False)
+
+            # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
+            # in the first _update() instead of the latest ones. This makes the first .update() faster.
+            petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
+            self._need_latest_infos = False
         else:
             self.sequence_info = sequence_info
             assert block_uids == sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
-
-        if start:
-            self.run_in_background()
-
-    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
-        """
-        Starts the updater thread in a background. if await_ready, this method will wait until sequence manager
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self._thread.start()
-        if await_ready:
-            self._thread.ready.wait(timeout)
+            self._need_latest_infos = True
 
     def make_sequence(
         self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
@@ -115,10 +108,10 @@ class RemoteSequenceManager:
         :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
         :param mode: either random or fastest
         """
-        if not self.is_alive():
-            logger.error("Using a sequence manager that is not running: it has either crashed or never started")
+        with self._thread_start_lock:
+            if not self.is_alive():
+                self._thread.start()
         if not self.ready.is_set():
-            logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
             self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)
 
         end_index = end_index if end_index is not None else len(self)
@@ -163,7 +156,6 @@ class RemoteSequenceManager:
             rpc_info=self._rpc_info,
             allowed_servers=self.allowed_servers,
             banned_peers=self.banned_peers,
-            start=True,
         )
 
     def update(self, *, wait: bool):
@@ -178,8 +170,10 @@ class RemoteSequenceManager:
         for attempt_no in itertools.count():
             try:
                 new_block_infos = petals.dht_utils.get_remote_module_infos(
-                    self.dht, self.block_uids, expiration_time=float("inf")
+                    self.dht, self.block_uids, latest=self._need_latest_infos
                 )
+                self._need_latest_infos = True  # All future _update() should use latest infos
+
                 for block_info in new_block_infos:
                     if not block_info:
                         continue
@@ -259,6 +253,10 @@ class RemoteSequenceManager:
     def rpc_info(self):
         """Return the rpc_info queried from one of the servers that hold the first block"""
         if self._rpc_info is None:
+            with self._thread_start_lock:
+                if not self.is_alive():
+                    self._thread.start()
+
             for attempt_no in itertools.count():
                 peer_id = None
                 try:
@@ -320,18 +318,11 @@ class _SequenceManagerUpdateThread(threading.Thread):
         self.ref_update_manager = ref_update_manager
         self.ready = threading.Event()
         self.trigger = threading.Event()
-        self.last_update_time = -float("inf")
         self.update_period = update_period
         self.should_shutdown = False
 
     def run(self) -> None:
         while not self.should_shutdown:
-            self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time)))
-
-            if self.should_shutdown:
-                logger.debug(f"{self.__class__.__name__} is shutting down")
-                break
-
             update_manager = self.ref_update_manager()
             if update_manager is None:
                 logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
@@ -345,16 +336,18 @@ class _SequenceManagerUpdateThread(threading.Thread):
             finally:
                 del update_manager
 
+            self.trigger.wait(self.update_period)
+
         logger.debug(f"{self.__class__.__name__} thread exited")
 
     def shutdown(self, timeout: Optional[float] = None):
         self.should_shutdown = True
         self.trigger.set()
-        self.join(timeout)
+        if self.is_alive():
+            self.join(timeout)
 
     def __del__(self):
-        if self.is_alive():
-            self.shutdown()
+        self.shutdown()
 
 
 def maybe_log_traceback(exc: Exception):

+ 17 - 12
src/petals/dht_utils.py

@@ -93,7 +93,7 @@ async def _get_remote_sequence(
 ) -> petals.client.RemoteSequential:
     uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
     p2p = await dht.replicate_p2p()
-    manager = petals.client.RemoteSequenceManager(dht, uids, p2p, start=True)
+    manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
     return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
 
 
@@ -124,7 +124,7 @@ async def _get_remote_module(
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()
-    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p, start=True) for uid in uids)
+    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
     modules = [
         petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
         for m in managers
@@ -133,21 +133,26 @@ async def _get_remote_module(
 
 
 def get_remote_module_infos(
-    dht: DHT, uid_or_uids: Union[ModuleUID, Sequence[ModuleUID]], expiration_time: Optional[DHTExpiration] = None
-) -> List[Optional[RemoteModuleInfo]]:
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time),
-        return_future=False,
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    expiration_time: Optional[DHTExpiration] = None,
+    *,
+    latest: bool = False,
+    return_future: bool = False,
+) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
+    return dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+        return_future=return_future,
     )
-    return infos[0] if single_uid else infos
 
 
 async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
+    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
 ) -> List[Optional[RemoteModuleInfo]]:
-    if expiration_time is None:
+    if latest:
+        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
+        expiration_time = math.inf
+    elif expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

+ 2 - 2
src/petals/server/server.py

@@ -324,14 +324,14 @@ class Server:
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
         time.sleep(random.random() * 2 * self.mean_block_selection_delay)
-        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.choose_best_blocks(self.num_blocks, module_infos)
 
     def _should_choose_other_blocks(self) -> bool:
         if self.strict_block_indices is not None:
             return False
 
-        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
     def shutdown(self):

+ 1 - 1
tests/test_remote_sequential.py

@@ -48,7 +48,7 @@ def test_remote_sequential():
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     lossy_sequential = RemoteSequential(
-        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True)
+        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
     )
 
     test_inputs.grad = None

+ 1 - 1
tests/test_sequence_manager.py

@@ -26,7 +26,7 @@ def test_sequence_manager_basics(mode: str):
     sequential = RemoteSequential(
         config,
         dht,
-        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
+        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt),
     )
 
     sequence = sequential.sequence_manager.make_sequence(mode=mode)