Browse source code

Add `allowed_servers`, `max_retries` options to the client, improve logs (#235)

Alexander Borzunov 2 years ago
parent
commit
9954cb84fe

+ 3 - 2
src/petals/client/inference_session.py

@@ -307,10 +307,11 @@ class InferenceSession:
                 except Exception as e:
                     if span is not None:
                         self._sequence_manager.on_request_failure(span.peer_id)
+                    if attempt_no + 1 == self._sequence_manager.max_retries:
+                        raise
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(
-                        f"Caught exception when running inference from block {block_idx} "
-                        f"(retry in {delay:.0f} sec): {repr(e)}"
+                        f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}"
                     )
                     maybe_log_traceback(e)
                     time.sleep(delay)
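
For context, the retry loops touched in this commit all share one shape: report the failed peer, re-raise once the attempt budget is spent, otherwise back off and retry. A minimal self-contained sketch of that shape (run_with_retries and step are illustrative names, not the Petals API; the backoff formula follows the min_backoff docstring further below):

    import time
    from typing import Callable, Optional, TypeVar

    T = TypeVar("T")

    def get_retry_delay(attempt_no: int, min_backoff: float = 1.0) -> float:
        # Retry immediately after the first failure; then sleep min_backoff * 2 ** (num_failures - 1)
        return 0.0 if attempt_no == 0 else min_backoff * 2 ** (attempt_no - 1)

    def run_with_retries(step: Callable[[], T], max_retries: Optional[int] = None) -> T:
        attempt_no = 0
        while True:
            try:
                return step()
            except Exception:
                # With max_retries=None, `attempt_no + 1 == None` is never true, so we retry forever
                if attempt_no + 1 == max_retries:
                    raise
                time.sleep(get_retry_delay(attempt_no))
                attempt_no += 1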

+ 10 - 2
src/petals/client/remote_model.py

@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Collection, List, Optional, Union
 
 import hivemind
 import torch
@@ -35,6 +35,10 @@ class DistributedBloomConfig(BloomConfig):
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     request_timeout: int = 30  # number of seconds to wait for a result from each node
+    max_retries: Optional[int] = None  # max number of retries before the client raises an exception (default: inf)
+    allowed_servers: Optional[
+        Collection[Union[str, hivemind.PeerID]]
+    ] = None  # if defined, send requests only to these servers
 
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
@@ -112,7 +116,11 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
             )
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
+        self.h = RemoteSequential(
+            config,
+            dht,
+            config.dht_prefix,
+        )
 
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
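
Taken together, the two new config fields let a client pin requests to trusted servers and fail fast instead of retrying forever. A hedged usage sketch (the model name and the passthrough of config kwargs via from_pretrained are assumptions based on the usual transformers-style API; the peer ID is a placeholder):

    from petals import DistributedBloomForCausalLM

    model = DistributedBloomForCausalLM.from_pretrained(
        "bigscience/bloom-petals",
        allowed_servers=["12D3KooWExamplePeerId"],  # base58 strings or hivemind.PeerID objects
        max_retries=3,  # give up after 3 failed attempts instead of retrying indefinitely
    )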

+ 10 - 1
src/petals/client/remote_sequential.py

@@ -41,7 +41,16 @@ class RemoteSequential(nn.Module):
         block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs)
+            self.sequence_manager = RemoteSequenceManager(
+                dht,
+                block_uids,
+                self.p2p,
+                request_timeout=config.request_timeout,
+                max_retries=config.max_retries,
+                allowed_servers=config.allowed_servers,
+                start=True,
+                **kwargs,
+            )
             self.is_subsequence = False
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")

+ 26 - 3
src/petals/client/routing/sequence_manager.py

@@ -6,7 +6,7 @@ import logging
 import random
 import threading
 import time
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Collection, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
 import numpy as np
@@ -40,9 +40,10 @@ class RemoteSequenceManager:
     :param update_period: by default, refresh DHT information once in this many seconds
     :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
     :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
+    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
     :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
     :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
-    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    :param allowed_servers: if defined, send requests only to these servers
     :param start: start the background thread (see the note below). If false, you will need to start it manually.
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
       running redundant sequence managers for the same set of layers.
@@ -56,21 +57,30 @@ class RemoteSequenceManager:
         p2p: P2P,
         update_period: float = 30,
         request_timeout: float = 30,
+        max_retries: Optional[int] = None,
         min_backoff: float = 1,
         ban_timeout: float = 15,
         sequence_info: Optional[RemoteSequenceInfo] = None,
         rpc_info: Optional[dict] = None,
+        allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
         banned_peers: Optional[Blacklist] = None,
         *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
         start: bool,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff
+        self.request_timeout, self.max_retries = request_timeout, max_retries
+        self.ban_timeout, self.min_backoff = ban_timeout, min_backoff
         self.lock_changes = threading.Lock()
         self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
         self.policy = NoSpendingPolicy()
         self._rpc_info = rpc_info
+
+        if allowed_servers is not None:
+            allowed_servers = {
+                PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
+            }
+        self.allowed_servers = allowed_servers
         self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
 
         if sequence_info is None:
@@ -148,6 +158,7 @@ class RemoteSequenceManager:
             min_backoff=self.min_backoff,
             sequence_info=self.sequence_info[ix],
             rpc_info=self._rpc_info,
+            allowed_servers=self.allowed_servers,
             banned_peers=self.banned_peers,
             start=True,
         )
@@ -169,6 +180,16 @@ class RemoteSequenceManager:
                 for block_info in new_block_infos:
                     if not block_info:
                         continue
+
+                    # Apply whitelist, if defined
+                    if self.allowed_servers is not None:
+                        block_info.servers = {
+                            peer_id: server_info
+                            for peer_id, server_info in block_info.servers.items()
+                            if peer_id in self.allowed_servers
+                        }
+
+                    # Remove temporarily banned peers, unless there are no peers left
                     valid_servers = {
                         peer_id: server_info
                         for peer_id, server_info in block_info.servers.items()
@@ -260,6 +281,8 @@ class RemoteSequenceManager:
                 except Exception as e:
                     if peer_id is not None and not isinstance(e, P2PHandlerError):
                         self.on_request_failure(peer_id)
+                    if attempt_no + 1 == self.max_retries:
+                        raise
                     delay = self.get_retry_delay(attempt_no)
                     logger.warning(
                         f"Caught exception when gathering information from peer {peer_id} "

+ 6 - 4
src/petals/client/sequential_autograd.py

@@ -95,10 +95,11 @@ async def sequential_forward(
             except Exception as e:
                 if span is not None:
                     sequence_manager.on_request_failure(span.peer_id)
+                if attempt_no + 1 == sequence_manager.max_retries:
+                    raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
-                    f"Caught exception when running forward from block {block_idx} "
-                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                    f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}"
                 )
                 maybe_log_traceback(e)
                 await asyncio.sleep(delay)
@@ -172,10 +173,11 @@ async def sequential_backward(
             except Exception as e:
                 if span is not None:
                     sequence_manager.on_request_failure(span.peer_id)
+                if attempt_no + 1 == sequence_manager.max_retries:
+                    raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
-                    f"Caught exception when running backward between blocks {span.start}-{span.end} "
-                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                    f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}"
                 )
                 maybe_log_traceback(e)
                 await asyncio.sleep(delay)
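
The forward and backward paths apply the same retry cap as inference, with one difference: they sleep with await asyncio.sleep, so backoff does not block the event loop. A hedged async sketch of the pattern (illustrative names, not the Petals API):

    import asyncio
    from typing import Awaitable, Callable, Optional, TypeVar

    T = TypeVar("T")

    async def retry_async(
        step: Callable[[], Awaitable[T]], max_retries: Optional[int] = None, min_backoff: float = 1.0
    ) -> T:
        attempt_no = 0
        while True:
            try:
                return await step()
            except Exception:
                if attempt_no + 1 == max_retries:  # never true when max_retries is None
                    raise
                delay = 0.0 if attempt_no == 0 else min_backoff * 2 ** (attempt_no - 1)
                await asyncio.sleep(delay)  # yields to the event loop instead of blocking it
                attempt_no += 1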

+ 2 - 1
src/petals/server/block_selection.py

@@ -16,6 +16,7 @@ class Span:
     start: int
     end: int
     throughput: float
+    state: ServerState
 
     @property
     def length(self):
@@ -43,7 +44,7 @@ def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
             if peer_id in spans:
                 spans[peer_id].start = min(spans[peer_id].start, block)
                 spans[peer_id].end = max(spans[peer_id].start, block + 1)
             else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
+                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
 
             throughputs[block] += server.throughput
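
Carrying the server's state on each Span lets downstream selection logic treat servers that are still joining differently from fully online ones. A hedged sketch of one such filter (this ServerState enum is a simplified stand-in for the one in petals.data_structures, and online_spans is a hypothetical helper):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Dict

    class ServerState(Enum):  # stand-in; the real enum lives in petals.data_structures
        OFFLINE = 0
        JOINING = 1
        ONLINE = 2

    @dataclass
    class Span:
        start: int
        end: int
        throughput: float
        state: ServerState

    def online_spans(spans: Dict[str, Span]) -> Dict[str, Span]:
        # Hypothetical helper: keep only spans whose server is fully online
        return {peer_id: span for peer_id, span in spans.items() if span.state == ServerState.ONLINE}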