@@ -16,21 +16,21 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
-from typing import List, Optional
+from typing import List, Optional, Sequence

 import uvloop

 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import SharedFuture, Endpoint, run_in_background
+from hivemind.utils import MPFuture, Endpoint, run_in_background


 class DHT(mp.Process):
     """
     A high-level interface to hivemind DHT. Runs a dht node in a background process.

-    :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
+    :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
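The docstring change above reflects a format change for initial_peers: endpoint strings in the same "host:port" style as listen_on, instead of (host, port) tuples. A minimal before/after sketch (the peer addresses are made-up placeholders):

    # before this change: pairs of (host, port)
    old_style_peers = [("192.168.1.10", 1337), ("192.168.1.11", 31337)]
    # after this change: endpoint strings, same format as listen_on
    new_style_peers = ["192.168.1.10:1337", "192.168.1.11:31337"]
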
@@ -42,12 +42,12 @@ class DHT(mp.Process):
     EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format

-    def __init__(self, *initial_peers: Endpoint, listen_on: Endpoint = "0.0.0.0:*", start: bool, daemon: bool = True,
-                 max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
+    def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
+                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
-        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after server starts
+        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self.node: Optional[DHTNode] = None  # initialized inside self.run only
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
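With the reordered signature above, listen_on and initial_peers can be passed positionally while start becomes keyword-only. A minimal construction sketch, assuming DHT is importable from hivemind.dht and using a placeholder peer endpoint:

    from hivemind.dht import DHT

    # `start` must now be given as a keyword; initial_peers is a sequence of endpoint strings
    dht = DHT(listen_on="0.0.0.0:*", initial_peers=["192.168.1.10:1337"], start=True)
    dht.ready.wait()  # `ready` is the mp.Event created in __init__; assumed to be set once the node is running
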
@@ -99,11 +99,11 @@
         :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
         :returns: a list of [RemoteExpert if found else None]
         """
-        future, _future = SharedFuture.make_pair()
+        future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
         return future.result()

-    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: SharedFuture):
+    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: MPFuture):
         loop = asyncio.get_event_loop()
         expiration = expiration or get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
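The wrapper that posts the '_get_experts' message above is the public lookup call. A usage sketch, reusing the dht handle from the construction sketch and assuming the wrapper is named get_experts and that RemoteExpert exposes the endpoint it was constructed with:

    uids = ["expert.0", "expert.9"]
    experts = dht.get_experts(uids)  # list of RemoteExpert (found) or None (missing/expired)
    for uid, expert in zip(uids, experts):
        if expert is not None:
            print(f"{uid} is served at {expert.endpoint}")
        else:
            print(f"{uid} was not found before its expiration")
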
@@ -114,41 +114,40 @@

         experts: List[Optional[RemoteExpert]] = [None] * len(uids)
         for i, (key, uid) in enumerate(zip(keys, uids)):
-            maybe_result, maybe_expiration = response[key]
+            maybe_endpoint, maybe_expiration = response[key]
             if maybe_expiration is not None:  # if we found a value
-                experts[i] = RemoteExpert(uid=uid, host=maybe_result[0], port=maybe_result[1])
+                experts[i] = RemoteExpert(uid=uid, endpoint=maybe_endpoint)

         future.set_result(experts)

-    def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
+    def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
-        Make experts available to DHT; update timestamps if already available
+        Make experts visible to all DHT peers; update timestamps if declared previously.

         :param uids: a list of expert ids to update
-        :param addr: hostname that can be used to call this expert
-        :param port: port that can be used to call this expert
+        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
         :param wait: if True, awaits for declaration to finish, otherwise runs in background
         :param timeout: waits for the procedure to finish, None means wait indeninitely
         :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
         """
-        future, _future = SharedFuture.make_pair() if wait else (None, None)
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, future=_future)))
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        future, _future = MPFuture.make_pair() if wait else (None, None)
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
         if wait:
             return future.result(timeout)

-    def _declare_experts(self, uids: List[str], addr: str, port: int, future: Optional[SharedFuture]):
+    def _declare_experts(self, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         assert self.node is not None, "This method should only be accessed from inside .run method"
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         loop = asyncio.get_event_loop()
         expiration_time = get_dht_time() + self.EXPIRATION
         unique_prefixes = set()
-        coroutines = []

         keys, values = [], []
         for uid in uids:
             uid_parts = uid.split(self.UID_DELIMETER)
             keys.append(self.make_key('expert', uid))
-            values.append((addr, port))
+            values.append(endpoint)
             unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])

         for prefix in unique_prefixes:
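A short sketch of the updated declare_experts call, which now takes a single endpoint string rather than separate addr and port (the endpoint below is a placeholder in the format the docstring suggests):

    # blocking form: returns a list of booleans (True = store succeeded)
    stored = dht.declare_experts(["expert.0", "expert.9"], endpoint="192.168.1.10:1337")
    assert stored is not None and all(stored), "some declarations were rejected by the DHT"

    # background form: wait=False sends the request and returns None immediately
    dht.declare_experts(["expert.7"], endpoint="192.168.1.10:1337", wait=False)
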
@@ -171,12 +170,12 @@ class DHT(mp.Process):
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
         assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
-        future, _future = SharedFuture.make_pair()
+        future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
                         dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
         return future.result()

-    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: SharedFuture):
+    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: MPFuture):
         assert self.node is not None, "This method should only be accessed from inside .run method"
         max_prefetch = max_prefetch or len(prefixes)
         loop = asyncio.get_event_loop()
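Last, a usage sketch for the prefix query whose internals change above, assuming the public wrapper is named first_k_active; the "expert.<i>" prefixes below are placeholders:

    candidate_prefixes = [f"expert.{i}" for i in range(10)]
    # returns at most k prefixes, each of which has at least one active (non-expired) expert
    active_prefixes = dht.first_k_active(candidate_prefixes, k=2)
    print(active_prefixes)  # e.g. ["expert.0", "expert.3"], depending on which experts are declared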
|