@@ -1,36 +1,57 @@
+"""
+This sub-module implements a node in a Kademlia-based DHT. The code is organized as follows:
+ * class DHT (below) - a high-level interface to the DHT used during model training. Runs a DHTNode in a background process.
+ * class DHTNode (node.py) - an asyncio implementation of a DHT server that stores and gets keys.
+ * class KademliaProtocol (protocol.py) - an asyncio-based RPC protocol used to request data from other DHT nodes.
+
+The code in this module is a modified version of https://github.com/bmuller/kademlia
+Brian, if you're reading this: THANK YOU! you're awesome :)
+"""
 import asyncio
-import datetime
 import multiprocessing as mp
 import warnings
 from typing import Tuple, List, Optional

-from kademlia.network import Server
+from .node import DHTNode, DHTID, DHTExpiration
+from .routing import get_dht_time

-from hivemind.client import RemoteExpert
-from hivemind.utils import run_forever, SharedFuture, PickleSerializer
+from ..client import RemoteExpert
+from ..utils import SharedFuture, find_open_port, Hostname, Port, run_in_background


-class DHTNode(mp.Process):
+class DHT(mp.Process):
+    """
+    A high-level interface to the hivemind DHT. Runs a DHT node in a background process.
+
+    :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
+    :param port: a port where DHT will listen for incoming connections. Defaults to hivemind.utils.find_open_port
+    :param start: if True, automatically starts the background process on creation; otherwise, start it manually with run_in_background
+    :param daemon: if True, the background process is marked as daemon and is terminated when the main process exits
+    :param node_params: any other params will be forwarded to DHTNode upon creation
+    """
     UID_DELIMETER = '.'  # splits expert uids over this delimiter
-    HEARTBEAT_EXPIRATION = 120  # expert is inactive iff it fails to post timestamp for *this many seconds*
+    EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format

-    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False, daemon=True):
+    def __init__(self, *initial_peers: Tuple[Hostname, Port], port: Optional[Port] = None,
+                 start: bool, daemon: bool = True, **node_params):
         super().__init__()
-        self.port, self.initial_peers = port, initial_peers
+        port = find_open_port() if port is None else port
+        self.node: Optional[DHTNode] = None  # to be initialized in self.run
+        self.port, self.initial_peers, self.node_params = port, initial_peers, node_params
         self._pipe, self.pipe = mp.Pipe(duplex=False)
         self.ready = mp.Event()
-        self.server = Server()
         self.daemon = daemon
         if start:
             self.run_in_background(await_ready=True)
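For orientation, a minimal usage sketch of the new constructor; the peer address and port are hypothetical, not part of this diff:

```python
# A minimal sketch, assuming the module is importable as hivemind.dht and
# that 192.168.1.2:8081 is an already-running DHT peer (hypothetical address).
from hivemind.dht import DHT

dht = DHT(('192.168.1.2', 8081), start=True)  # spawns the background process and waits until it is ready
# ... dht.declare_experts(...) / dht.get_experts(...) / dht.first_k_active(...) ...
dht.shutdown()  # terminate the background process when done
```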

     def run(self) -> None:
+        if asyncio.get_event_loop().is_running():
+            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        loop.run_until_complete(self.server.listen(self.port))
-        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
-        run_forever(loop.run_forever)
+
+        self.node = DHTNode(initial_peers=list(self.initial_peers), port=self.port, **self.node_params)
+        run_in_background(loop.run_forever)
         self.ready.set()

         while True:
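The public methods below talk to this background process by sending `(method_name, args, kwargs)` tuples through `self.pipe`. The body of the dispatch loop is elided by the next hunk; a plausible sketch, inferred only from those `pipe.send` calls, would be:

```python
# Hedged sketch of the elided dispatch loop: the read end of the pipe receives
# tuples such as ('_get_experts', [], {...}) sent by the public methods and
# dispatches them to the matching private method. The actual loop body is not
# shown in this diff.
while True:
    method_name, args, kwargs = self._pipe.recv()
    getattr(self, method_name)(*args, **kwargs)
```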
@@ -39,7 +60,7 @@ class DHTNode(mp.Process):

     def run_in_background(self, await_ready=True, timeout=None):
         """
-        Starts DHTNode in a background process. if await_ready, this method will wait until background dht
+        Starts DHT in a background process. If await_ready, this method will wait until the background dht
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
@@ -53,98 +74,106 @@ class DHTNode(mp.Process):
         else:
             warnings.warn("DHT shutdown has no effect: dht process is already not alive")

-    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
-        """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]"""
+    def get_experts(self, uids: List[str], expiration=None) -> List[Optional[RemoteExpert]]:
+        """
+        :param uids: find experts with these ids from across the DHT
+        :param expiration: return experts that expire no sooner than this (based on get_dht_time); default = now
+        :returns: a list of [RemoteExpert if found else None]
+        """
         future, _future = SharedFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
+        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
         return future.result()
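A short usage sketch for the lookup path; the uids are hypothetical, and we assume RemoteExpert keeps its host/port constructor arguments as attributes:

```python
# Hypothetical expert uids; each result is a RemoteExpert if the uid was found
# with an unexpired DHT entry, else None.
uids = ['ffn.0.3', 'ffn.1.7']
for uid, expert in zip(uids, dht.get_experts(uids)):
    if expert is not None:
        print(f"{uid} is served at {expert.host}:{expert.port}")
```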

-    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
+    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: SharedFuture):
         loop = asyncio.get_event_loop()
+        expiration = expiration or get_dht_time()
+
         lookup_futures = [asyncio.run_coroutine_threadsafe(
-            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
-        current_time = datetime.datetime.now()
+            self.node.get(self.make_key('expert', uid), expiration), loop) for uid in uids]

-        experts = [None] * len(uids)
+        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
         for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
-            if lookup.result() is not None:
-                (host, port), timestamp = PickleSerializer.loads(lookup.result())
-                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
-                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
+            maybe_result, maybe_expiration = lookup.result()
+            if maybe_expiration is not None:  # if we found a value
+                experts[i] = RemoteExpert(uid=uid, host=maybe_result[0], port=maybe_result[1])

         future.set_result(experts)

-    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
+    def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
         """
         Make experts available to DHT; update timestamps if already available
         :param uids: a list of expert ids to update
         :param addr: hostname that can be used to call this expert
         :param port: port that can be used to call this expert
-        :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
+        :param wait: if True, waits for the declaration to finish, otherwise runs in background
+        :param timeout: waits at most this many seconds for the procedure to finish; None means wait indefinitely
+        :returns: if wait, returns a list of booleans (True = store succeeded, False = store rejected)
         """
-        done_event = mp.Event() if wait_timeout else None
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
-        if done_event is not None:
-            done_event.wait(wait_timeout)
+        future, _future = SharedFuture.make_pair() if wait else (None, None)
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, future=_future)))
+        if wait:
+            return future.result(timeout)
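The matching declaration call, sketched with hypothetical uids and address; since every entry expires after `EXPIRATION` seconds, a server is expected to re-declare its experts periodically:

```python
# Hypothetical uids/address; with wait=True this returns one boolean per store
# (one per expert key, plus one per unique prefix key).
statuses = dht.declare_experts(['ffn.0.3', 'ffn.1.7'], addr='192.168.1.2', port=9000)
assert statuses is not None and all(statuses), "some stores were rejected"
```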

-    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
+    def _declare_experts(self, uids: List[str], addr: str, port: int, future: Optional[SharedFuture]):
+        assert self.node is not None, "This method should only be accessed from inside .run method"
         loop = asyncio.get_event_loop()
-        timestamp = datetime.datetime.now()
-        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
-        prefix_metadata = PickleSerializer.dumps(timestamp)
-
+        expiration_time = get_dht_time() + self.EXPIRATION
         unique_prefixes = set()
+        coroutines = []

         for uid in uids:
-            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
+            coroutines.append(asyncio.run_coroutine_threadsafe(
+                self.node.store(self.make_key('expert', uid), value=(addr, port),
+                                expiration_time=expiration_time),
+                loop))
             uid_parts = uid.split(self.UID_DELIMETER)
             unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])

         for prefix in unique_prefixes:
-            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)
+            coroutines.append(asyncio.run_coroutine_threadsafe(
+                self.node.store(self.make_key('prefix', prefix), True, expiration_time), loop))

-        if done_event is not None:
-            done_event.set()
+        if future is not None:
+            future.set_result([coro.result() for coro in coroutines])  # wait for all coroutines to finish

-    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
+    def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
         :param prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
-        :param heartbeat_expiration: consider expert active if his last heartbeat was sent at most this many seconds ago
         :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
+        assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
         future, _future = SharedFuture.make_pair()
-        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
-                                                    max_prefetch=max_prefetch or k, future=_future)))
+        self.pipe.send(('_first_k_active', [],
+                        dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
        return future.result()
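A usage sketch of this beam-search primitive (prefixes are hypothetical): candidates are ordered by priority, and the call keeps the first k that have at least one active expert:

```python
# Hypothetical prefix uids ordered from best to worst.
candidates = ['ffn.0', 'ffn.1', 'ffn.2', 'ffn.3']
print(dht.first_k_active(candidates, k=2))  # e.g. ['ffn.0', 'ffn.2'] if ffn.1 has no active experts
```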

-    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
+    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: SharedFuture):
+        assert self.node is not None, "This method should only be accessed from inside .run method"
+        max_prefetch = max_prefetch or len(prefixes)
         loop = asyncio.get_event_loop()
-        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
-            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
-        current_time = datetime.datetime.now()
-
+        lookup_prefetch = [asyncio.run_coroutine_threadsafe(self.node.get(self.make_key('prefix', prefix)), loop)
+                           for prefix in prefixes[:max_prefetch]]
         active_prefixes = []

         for i, prefix in enumerate(prefixes):
-            lookup = lookup_prefetch[i]
-            if lookup.result() is not None:
-                timestamp = PickleSerializer.loads(lookup.result())
-                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
-                    active_prefixes.append(prefix)
-                    if len(active_prefixes) >= k:
-                        future.set_result(active_prefixes)
-                        return
+            _, maybe_expiration = lookup_prefetch[i].result()
+            if maybe_expiration is not None:
+                active_prefixes.append(prefix)
+                if len(active_prefixes) >= k:
+                    future.set_result(active_prefixes)
+                    for task in lookup_prefetch[i:]:
+                        task.cancel()
+                    return

             # pre-dispatch the next request in line
             if len(lookup_prefetch) < len(prefixes):
                 lookup_prefetch.append(
-                    asyncio.run_coroutine_threadsafe(self.server.get(
-                        self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop)
-                )
+                    asyncio.run_coroutine_threadsafe(
+                        self.node.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))

         # could not find enough active prefixes; return what we can
         future.set_result(active_prefixes)