|
@@ -7,7 +7,7 @@ import logging
|
|
|
import random
|
|
|
import threading
|
|
|
import time
|
|
|
-from typing import Any, Collection, Dict, List, Optional, Sequence, Union
|
|
|
+from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
|
|
|
from weakref import WeakMethod
|
|
|
|
|
|
import dijkstar
|
|
@@ -38,6 +38,7 @@ class SequenceManagerConfig:
|
|
|
|
|
|
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
|
|
|
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
|
|
|
+ blocked_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, do not use these servers
|
|
|
use_server_to_server: bool = True # Use direct server-to-server communication
|
|
|
|
|
|
connect_timeout: float = 5 # timeout for opening a connection
|
|
@@ -116,6 +117,9 @@ class RemoteSequenceManager:
|
|
|
self._thread_start_lock = threading.Lock()
|
|
|
self.policy = NoSpendingPolicy()
|
|
|
|
|
|
+ self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
|
|
|
+ self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
|
|
|
+
|
|
|
self.ping_aggregator = PingAggregator(dht)
|
|
|
|
|
|
if state.banned_peers is None:
|
|
@@ -128,6 +132,23 @@ class RemoteSequenceManager:
|
|
|
self._thread.ready.set() # no need to await the first dht fetch
|
|
|
self._need_latest_infos = True
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
|
|
|
+ if peer_ids is None:
|
|
|
+ return None
|
|
|
+
|
|
|
+ result = set()
|
|
|
+ for peer_id in peer_ids:
|
|
|
+ if isinstance(peer_id, PeerID):
|
|
|
+ result.add(peer_id)
|
|
|
+ elif isinstance(peer_id, str):
|
|
|
+ result.add(PeerID.from_base58(peer_id))
|
|
|
+ else:
|
|
|
+ raise TypeError(
|
|
|
+ f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
|
|
|
+ )
|
|
|
+ return result
|
|
|
+
|
|
|
def make_sequence(
|
|
|
self,
|
|
|
start_index: int = 0,
|
|
@@ -341,13 +362,13 @@ class RemoteSequenceManager:
|
|
|
if not block_info:
|
|
|
continue
|
|
|
|
|
|
- # Apply whitelist, if defined
|
|
|
- if self.config.allowed_servers is not None:
|
|
|
- block_info.servers = {
|
|
|
- peer_id: server_info
|
|
|
- for peer_id, server_info in block_info.servers.items()
|
|
|
- if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
|
|
|
- }
|
|
|
+ # Apply allow and block lists
|
|
|
+ block_info.servers = {
|
|
|
+ peer_id: server_info
|
|
|
+ for peer_id, server_info in block_info.servers.items()
|
|
|
+ if (self.allowed_servers is None or peer_id in self.allowed_servers)
|
|
|
+ and (self.blocked_servers is None or peer_id not in self.blocked_servers)
|
|
|
+ }
|
|
|
|
|
|
# Remove temporarily banned peers, unless there are no peers left
|
|
|
valid_servers = {
|