Procházet zdrojové kódy

Add `blocked_servers` argument (#462)

Should be used as:

```python
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, blocked_servers=[peer_id1, peer_id2])
```
Alexander Borzunov před 2 roky
rodič
revize
329f7d31e8

+ 29 - 8
src/petals/client/routing/sequence_manager.py

@@ -7,7 +7,7 @@ import logging
 import random
 import threading
 import time
-from typing import Any, Collection, Dict, List, Optional, Sequence, Union
+from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
 from weakref import WeakMethod
 
 import dijkstar
@@ -38,6 +38,7 @@ class SequenceManagerConfig:
 
     show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+    blocked_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, do not use these servers
     use_server_to_server: bool = True  # Use direct server-to-server communication
 
     connect_timeout: float = 5  # timeout for opening a connection
@@ -116,6 +117,9 @@ class RemoteSequenceManager:
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
 
+        self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
+        self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
+
         self.ping_aggregator = PingAggregator(dht)
 
         if state.banned_peers is None:
@@ -128,6 +132,23 @@ class RemoteSequenceManager:
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
 
+    @staticmethod
+    def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
+        if peer_ids is None:
+            return None
+
+        result = set()
+        for peer_id in peer_ids:
+            if isinstance(peer_id, PeerID):
+                result.add(peer_id)
+            elif isinstance(peer_id, str):
+                result.add(PeerID.from_base58(peer_id))
+            else:
+                raise TypeError(
+                    f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
+                )
+        return result
+
     def make_sequence(
         self,
         start_index: int = 0,
@@ -341,13 +362,13 @@ class RemoteSequenceManager:
             if not block_info:
                 continue
 
-            # Apply whitelist, if defined
-            if self.config.allowed_servers is not None:
-                block_info.servers = {
-                    peer_id: server_info
-                    for peer_id, server_info in block_info.servers.items()
-                    if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
-                }
+            # Apply allow and block lists
+            block_info.servers = {
+                peer_id: server_info
+                for peer_id, server_info in block_info.servers.items()
+                if (self.allowed_servers is None or peer_id in self.allowed_servers)
+                and (self.blocked_servers is None or peer_id not in self.blocked_servers)
+            }
 
             # Remove temporarily banned peers, unless there are no peers left
             valid_servers = {

+ 1 - 1
tests/test_remote_sequential.py

@@ -43,7 +43,7 @@ def test_remote_sequential():
     assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
 
     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
 
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]