@@ -16,6 +16,8 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
+from collections import deque
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Sequence

 import uvloop
@@ -23,12 +25,12 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import MPFuture, Endpoint, run_in_background
+from hivemind.utils import MPFuture, Endpoint


 class DHT(mp.Process):
     """
-    A high-level interface to hivemind DHT. Runs a dht node in a background process.
+    High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select the best experts.

     :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
@@ -36,19 +38,45 @@ class DHT(mp.Process):
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param receiver_threads: uses this many threads to await on the input pipe; default = 1 should be enough in most cases
     :param kwargs: any other params will be forwarded to DHTNode upon creation
+
+    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
+    An expert identifier consists of:
+
+    * an optional prefix that determines the expert's role, experiment name, etc.
+    * one or more integers that determine that expert's position in an N-dimensional grid
+
+    A hivemind.Server can call ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
+    When declaring experts, DHT will store each expert's uid and all its prefixes for up to :expiration: seconds
+    (specified at init). For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
+    ``"ffn_expert", "ffn_expert.98", "ffn_expert.98.76", ..., "ffn_expert.98.76.54.32.10"``
+
+    RemoteMixtureOfExperts can use these prefixes to find the top-k most suitable experts with a left-to-right beam search.
+    For instance, consider a RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
+    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
+    However, not every expert in such a 100^5 grid will be alive at a given moment of time (the grid size is redundant).
+    To find the k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
+    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
+    This is done with the DHT.first_k_active function. After selecting the k best indices along the first dimension,
+    MoE moves to the second dimension: it finds the top-k pairs of indices (e.g. "expert.98.76") that start with one
+    of the k indices from the previous step. Finally, MoE uses DHT.get_experts(uids: List[str]) to fetch the chosen experts.
+    This beam search explores one additional dimension per step and finds the k best experts from across the DHT
+    in O(k / s * log(N)) average time, where s is the grid sparsity rate and N is the total number of experts.
     """
-    UID_DELIMETER = '.'  # splits expert uids over this delimeter
-    EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
-    make_key = "{}::{}".format
+
+    UID_DELIMITER = '.'  # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
+    # formally, prefixes = {UID_DELIMITER.join(uid.split(UID_DELIMITER)[:length]) for length in range(1, uid.count(UID_DELIMITER) + 2)}

     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
-                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
+                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
+                 receiver_threads: int = 1, expiration: float = 300, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
-        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+        self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
+        self.expiration = expiration
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
-        self.node: Optional[DHTNode] = None  # initialized inside self.run only
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
         self.daemon = daemon
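Note: the prefix scheme described in the class docstring above can be summarized with a short standalone sketch (illustration only, not part of the patch); it mirrors the loop added to _declare_experts further down:

    UID_DELIMITER = '.'

    def uid_prefixes(uid: str):
        """every key that declaring `uid` would store, shortest prefix first"""
        parts = uid.split(UID_DELIMITER)
        return [UID_DELIMITER.join(parts[:i + 1]) for i in range(len(parts))]

    assert uid_prefixes("ffn_expert.98.76") == ["ffn_expert", "ffn_expert.98", "ffn_expert.98.76"]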
@@ -62,16 +90,20 @@ class DHT(mp.Process):
         uvloop.install()
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        self.node: DHTNode = loop.run_until_complete(DHTNode.create(
-            initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-            num_workers=self.max_workers or 1, **self.kwargs))
-        self._port.value = self.node.port
-        run_in_background(loop.run_forever)
-        self.ready.set()
+        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+
+        async def _run():
+            node = await DHTNode.create(
+                initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                num_workers=self.max_workers or 1, **self.kwargs)
+            self._port.value = node.port
+            self.ready.set()
+
+            while True:
+                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                asyncio.create_task(getattr(self, method)(node, *args, **kwargs))

-        while True:
-            method, args, kwargs = self._pipe.recv()
-            getattr(self, method)(*args, **kwargs)
+        loop.run_until_complete(_run())

     def run_in_background(self, await_ready=True, timeout=None):
         """
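Note: the rewritten run() above replaces the old synchronous pipe loop: a small thread pool of receiver_threads workers blocks on pipe.recv so the event loop stays responsive, and every received request is dispatched as its own asyncio task. A toy, self-contained version of this dispatch pattern (illustration only, not hivemind code) might look like:

    import asyncio
    import multiprocessing as mp
    from concurrent.futures import ThreadPoolExecutor

    async def handle(request):
        await asyncio.sleep(0.01)  # pretend to do some DHT work
        print("handled", request)

    async def serve(pipe, receiver_threads: int = 1):
        loop = asyncio.get_running_loop()
        pipe_awaiter = ThreadPoolExecutor(receiver_threads)
        while True:
            request = await loop.run_in_executor(pipe_awaiter, pipe.recv)  # blocking recv runs off the loop
            if request is None:  # sentinel: stop the toy server
                break
            asyncio.create_task(handle(request))  # dispatch and immediately resume receiving
        await asyncio.sleep(0.1)  # give in-flight tasks a moment to finish (toy cleanup only)

    if __name__ == "__main__":
        parent_end, child_end = mp.Pipe()
        parent_end.send("ping")
        parent_end.send(None)
        asyncio.run(serve(child_end))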
@@ -85,7 +117,7 @@ class DHT(mp.Process):
     def shutdown(self) -> None:
         """ Shuts down the dht process """
         if self.is_alive():
-            self.kill()
+            self.terminate()
         else:
             warnings.warn("DHT shutdown has no effect: dht process is already not alive")

@@ -93,32 +125,27 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None

-    def get_experts(self, uids: List[str], expiration=None) -> List[Optional[RemoteExpert]]:
+    def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
+                    wait=True) -> List[Optional[RemoteExpert]]:
         """
         :param uids: find experts with these ids from across the DHT
-        :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
+        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+        :param wait: if True (default), wait until the experts are retrieved and return them. Otherwise return a Future.
         :returns: a list of [RemoteExpert if found else None]
         """
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
-        return future.result()
+        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
+        return future.result() if wait else future

-    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: MPFuture):
-        loop = asyncio.get_event_loop()
-        expiration = expiration or get_dht_time()
+    async def _get_experts(
+            self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
+        if expiration_time is None:
+            expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        keys = [self.make_key('expert', uid) for uid in uids]
-
-        response = asyncio.run_coroutine_threadsafe(
-            self.node.get_many(keys, expiration, num_workers=num_workers), loop).result()
-
-        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
-        for i, (key, uid) in enumerate(zip(keys, uids)):
-            maybe_endpoint, maybe_expiration = response[key]
-            if maybe_expiration is not None:  # if we found a value
-                experts[i] = RemoteExpert(uid=uid, endpoint=maybe_endpoint)
-
-        future.set_result(experts)
+        response = await node.get_many(uids, expiration_time, num_workers=num_workers)
+        future.set_result([RemoteExpert(uid, maybe_endpoint) if maybe_expiration_time else None
+                           for uid, (maybe_endpoint, maybe_expiration_time) in response.items()])

     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
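Note: together, declare_experts and get_experts provide the round trip described in the class docstring. A usage sketch, assuming the DHT class from this patch is importable as hivemind.dht.DHT and using made-up endpoint/uid values:

    from hivemind.dht import DHT

    dht = DHT(listen_on="127.0.0.1:*", start=True)

    # a hivemind.Server would announce its experts under its own endpoint
    dht.declare_experts(["ffn_expert.98.76.54.32.10"], endpoint="127.0.0.1:1337")

    # any peer on the same DHT can now resolve the uid (None if not found or expired)
    expert, = dht.get_experts(["ffn_expert.98.76.54.32.10"])
    print(expert)  # RemoteExpert pointing at 127.0.0.1:1337

    dht.shutdown()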
@@ -136,69 +163,70 @@ class DHT(mp.Process):
         if wait:
             return future.result(timeout)

-    def _declare_experts(self, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
-        assert self.node is not None, "This method should only be accessed from inside .run method"
+    async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        loop = asyncio.get_event_loop()
-        expiration_time = get_dht_time() + self.EXPIRATION
-        unique_prefixes = set()
+        expiration_time = get_dht_time() + self.expiration

-        keys, values = [], []
+        data_to_store = {}
         for uid in uids:
-            uid_parts = uid.split(self.UID_DELIMETER)
-            keys.append(self.make_key('expert', uid))
-            values.append(endpoint)
-            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
-
-        for prefix in unique_prefixes:
-            keys.append(self.make_key('prefix', prefix))
-            values.append(True)
-
-        store_ok = asyncio.run_coroutine_threadsafe(
-            self.node.store_many(keys, values, expiration_time, num_workers=num_workers), loop
-        ).result()
+            uid_parts = uid.split(self.UID_DELIMITER)
+            for i in range(len(uid_parts)):
+                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
+                data_to_store[uid_prefix_i] = endpoint
+
+        store_keys, store_values = zip(*data_to_store.items())
+        store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
         if future is not None:
-            future.set_result([store_ok[key] for key in keys])
+            future.set_result([store_ok[key] for key in data_to_store.keys()])

-    def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
+    def first_k_active(self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None):
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search

-        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
+        :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
-        :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
+        :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size prefixes)
+        :param chunk_size: dispatch this many requests in one task
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
-        assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
+        assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
-                        dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
+                        dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
+                             chunk_size=chunk_size or k, future=_future)))
         return future.result()

-    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: MPFuture):
-        assert self.node is not None, "This method should only be accessed from inside .run method"
-        max_prefetch = max_prefetch or len(prefixes)
-        loop = asyncio.get_event_loop()
-        lookup_prefetch = [asyncio.run_coroutine_threadsafe(self.node.get(self.make_key('prefix', prefix)), loop)
-                           for prefix in prefixes[:max_prefetch]]
+    async def _first_k_active(
+            self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
+        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
+        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
         active_prefixes = []

-        for i, prefix in enumerate(prefixes):
-            _, maybe_expiration = lookup_prefetch[i].result()
-
-            if maybe_expiration is not None:
-                active_prefixes.append(prefix)
-                if len(active_prefixes) >= k:
-                    future.set_result(active_prefixes)
-                    for task in lookup_prefetch[i:]:
-                        task.cancel()
-                    return
-
-            # pre-dispatch the next request in line
-            if len(lookup_prefetch) < len(prefixes):
-                lookup_prefetch.append(
-                    asyncio.run_coroutine_threadsafe(
-                        self.node.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))
-
-        # could not find enough active prefixes; return what we can
+        pending_tasks = deque(
+            asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
+                                              num_workers=num_workers_per_chunk))
+            for chunk_i in range(min(max_prefetch + 1, total_chunks))
+        )  # pre-dispatch first task and up to max_prefetch additional tasks
+
+        for chunk_i in range(total_chunks):
+            # parse task results in chronological order, launch additional tasks on demand
+            response = await pending_tasks.popleft()
+            for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
+                if response[uid_prefix][1] is not None:  # found active peer
+                    active_prefixes.append(uid_prefix)
+                    # if we found enough active experts, finish immediately
+                    if len(active_prefixes) >= k:
+                        break
+            if len(active_prefixes) >= k:
+                for task in pending_tasks:
+                    task.cancel()
+                break
+
+            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
+            if pre_dispatch_chunk_i < total_chunks:
+                pending_tasks.append(asyncio.create_task(node.get_many(
+                    uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
+                    num_workers=num_workers_per_chunk)))
+
+        # return k active prefixes or as many as we could find
         future.set_result(active_prefixes)
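Note: first_k_active is the building block of the beam search outlined in the class docstring. A deliberately simplified sketch of that search (illustration only; the real RemoteMixtureOfExperts orders candidates by its gating scores rather than by ascending index):

    def naive_beam_search(dht, prefix="ffn_expert", grid=(100, 100, 100), k=4):
        """find up to k live experts in a grid; candidates are ranked by index instead of gating scores"""
        beam = [prefix]
        for dim_size in grid:
            # expand every surviving prefix by all candidate indices along the next dimension,
            # ordered from highest to lowest priority (here: simply ascending index)
            candidates = [f"{p}.{i}" for p in beam for i in range(dim_size)]
            beam = dht.first_k_active(candidates, k=k)
            if not beam:
                return []
        # after the last dimension, the surviving prefixes are full expert uids
        return [expert for expert in dht.get_experts(beam) if expert is not None]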