@@ -16,26 +16,53 @@ import asyncio
import ctypes
import heapq
import multiprocessing as mp
+import re
import warnings
-from collections import deque, OrderedDict
+from collections import deque
from concurrent.futures import ThreadPoolExecutor
-from itertools import chain
-from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set
+from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set

import uvloop

from hivemind.client import RemoteExpert
from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
-from hivemind.dht.routing import get_dht_time
+from hivemind.dht.routing import get_dht_time, DHTValue
+from hivemind.dht.storage import ValueWithExpiration
from hivemind.utils import MPFuture, Endpoint, get_logger

logger = get_logger(__name__)

+ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
+UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
+UID_DELIMITER = '.'  # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
+FLAT_EXPERT = -1  # grid coordinate reserved for storing 1d expert uids. Used to speed up find_best_experts in the 1d case.
+UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
+PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
+# formally, prefixes = {UID_DELIMITER.join(uid.split(UID_DELIMITER)[:length]) for length in range(1, uid.count(UID_DELIMITER) + 2)}
+
+
+def is_valid_uid(maybe_uid: str) -> bool:
+    return bool(UID_PATTERN.fullmatch(maybe_uid))
+
+
+def is_valid_prefix(maybe_prefix: str) -> bool:
+    return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
+
+
+def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
+    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
+    uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
+    pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
+    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
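+
+
+# Worked example (illustrative only, not executed): for the uid "ffn_expert.98.76.54",
+#   is_valid_uid("ffn_expert.98.76.54")  -> True
+#   is_valid_prefix("ffn_expert.98.")    -> True   (prefixes must end with the delimiter)
+#   split_uid("ffn_expert.98.76.54")     -> ("ffn_expert.98.76.", 54)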
+

class DHT(mp.Process):
    """
    High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select the best experts.

+    * hivemind servers periodically announce their experts via DHT.declare_experts
+    * trainers find the most suitable experts via DHT.find_best_experts
+
    :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
    :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
@@ -55,7 +82,10 @@ class DHT(mp.Process):
    A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
-    ``"ffn_expert", "ffn_expert.98", "ffn_expert.98.76", ..., "ffn_expert.98.76.54.32.10"``
+    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
+
+    In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
+    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....})

    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
@@ -63,14 +93,12 @@ class DHT(mp.Process):
    However, not every expert in such a 100^5 grid can be alive at a given moment of time (the grid size is redundant).
    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
-    This is done using DHT.first_k_active function. After selecting k best indices along first dimension, MoE moves
-    to the second dimension. It can find top-k pairs of indices (e.g. "expert.98.76") that start with one of k first
-    indices from the previous step. Finally, MoE will use DHT.get_experts(uids: List[str]) search for specific experts.
+
+    After selecting the k best indices along the first dimension, MoE moves to the second dimension.
+    It can find the top-k index pairs (e.g. "expert.98.76") that use one of the k best indices from the previous step.
    This beam search explores one additional dimension per step and finds k best experts from across the DHT
-    in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
+    in O(k * num_dimensions * dimension_size) time, depending on the chosen grid dimensions.
    """
-    UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
-    # formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}

    def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
@@ -129,31 +157,8 @@ class DHT(mp.Process):
    def port(self) -> Optional[int]:
        return self._port.value if self._port.value != 0 else None

-    def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
-                    return_future=False) -> List[Optional[RemoteExpert]]:
-        """
-        :param uids: find experts with these ids from across the DHT
-        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
-        :returns: a list of [RemoteExpert if found else None]
-        """
-        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_experts(
-            self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
-        if expiration_time is None:
-            expiration_time = get_dht_time()
-        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        response = await node.get_many(uids, expiration_time, num_workers=num_workers)
-        # TODO expert_data['expert'] -> namedtuple with meaningful field names
-        future.set_result([RemoteExpert(*expert_data.value['expert'].value)
-                           if expert_data is not None and 'expert' in expert_data.value else None
-                           for uid, expert_data in response.items()])
-
-    def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
+    def declare_experts(self, uids: Sequence[ExpertUID], endpoint: Endpoint, wait: bool = True,
+                        timeout: Optional[float] = None) -> Dict[ExpertUID, bool]:
        """
        Make experts visible to all DHT peers; update timestamps if declared previously.

@@ -161,38 +166,151 @@ class DHT(mp.Process):
        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
        :param wait: if True, waits for the declaration to finish, otherwise runs in background
        :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
-        :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
+        :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
        """
        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        for uid in uids:
+            assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
        future, _future = MPFuture.make_pair() if wait else (None, None)
        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
        if wait:
            return future.result(timeout)
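+
+    # Example usage (hypothetical uid/endpoint, assuming a running DHT handle `dht`):
+    #   dht.declare_experts(['ffn_expert.5.3'], endpoint='127.0.0.1:1337')  # -> {stored key: True/False, ...}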

-    async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
+    async def _declare_experts(self, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
+                               future: Optional[MPFuture]) -> Dict[ExpertUID, bool]:
        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
        expiration_time = get_dht_time() + self.expiration
-        unique_entries: Set[Tuple[str, str]] = set()
-        # prefix---v next_dim uid endpoint
-        data_to_store: List[Tuple[str, str, List[str, Endpoint]]] = []
-        for uid in uids:  # first k entries are expert uids themselves
-            data_to_store.append((uid, "expert", [uid, endpoint]))
-        for uid in uids:  # and then, add all prefixes
-            uid_parts = uid.split(self.UID_DELIMITER)
-            for i in range(len(uid_parts) - 1):
-                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
-                if (uid_prefix_i, uid_parts[i + 1]) in unique_entries:
-                    continue
-                unique_entries.add((uid_prefix_i, uid_parts[i + 1]))
-                data_to_store.append((uid_prefix_i, uid_parts[i + 1], [uid, endpoint]))
-
-        keys, subkeys, values = map(list, zip(*data_to_store))
-        store_ok = await node.store_many(keys, values, expiration_time, subkeys=subkeys, num_workers=num_workers)
+        data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+        for uid in uids:
+            data_to_store[uid, None] = endpoint
+            prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
+            for i in range(prefix.count(UID_DELIMITER) - 1):
+                prefix, last_coord = split_uid(prefix)
+                data_to_store[prefix, last_coord] = [uid, endpoint]
+
+        keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
+        store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
        if future is not None:
-            future.set_result([store_ok[key, subkey] for key, subkey in zip(keys, subkeys)])
+            future.set_result(store_ok)
+        return store_ok
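+
+    # Illustration (traced from the loop above, not executed): declaring uid "ffn_expert.98.76" at endpoint E stores
+    #   "ffn_expert.98.76" -> E                                  (the expert itself)
+    #   "ffn_expert.98."   -> {76: ["ffn_expert.98.76", E]}      (suffix dict used by beam search)
+    # A 1d uid such as "ffn_expert.7" is stored under "ffn_expert.7." with the FLAT_EXPERT (-1) coordinate.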
+
+    def get_experts(self, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
+                    return_future: bool = False) -> List[Optional[RemoteExpert]]:
+        """
+        :param uids: find experts with these ids from across the DHT
+        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: a list of [RemoteExpert if found else None]
+        """
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_experts', [], dict(uids=list(uids), expiration_time=expiration_time, future=_future)))
+        return future if return_future else future.result()
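+
+    # Example usage (hypothetical uid, assuming the expert above was declared at endpoint E):
+    #   dht.get_experts(['ffn_expert.98.76'])  # -> [RemoteExpert('ffn_expert.98.76', E)], or [None] if expired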

-    def find_best_experts(self, prefix: str, grid_scores: Sequence[Sequence[float]], beam_size: int, *,
-                          return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+    async def _get_experts(self, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration],
+                           future: Optional[MPFuture] = None) -> List[Optional[RemoteExpert]]:
+        if expiration_time is None:
+            expiration_time = get_dht_time()
+        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
+        found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+        for i, uid in enumerate(uids):
+            if found[uid] is not None and isinstance(found[uid].value, Endpoint):
+                experts[i] = RemoteExpert(uid, found[uid].value)
+        if future:
+            future.set_result(experts)
+        return experts
+
+    def get_initial_beam(self, prefix: ExpertPrefix, scores: Sequence[float], beam_size: int,
+                         num_workers: Optional[int] = None, return_future: bool = False
+                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        """
+        :param prefix: search for experts whose uids start with this prefix
+        :param scores: prefer suffix coordinates that have the highest scores
+        :param beam_size: select up to this many active suffixes with the highest scores
+        :param num_workers: maintain up to this many concurrent DHT searches
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
+        """
+        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_initial_beam', [], dict(prefix=prefix, scores=tuple(scores), beam_size=beam_size,
+                                                      num_workers=num_workers, future=_future)))
+        return future if return_future else future.result()
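+
+    # Illustrative return value (hypothetical uids/endpoints), e.g. for beam_size=2:
+    #   [(0.9, 'ffn.1.', {0: UidEndpoint(uid='ffn.1.0', endpoint='1.2.3.4:1337')}),
+    #    (0.5, 'ffn.0.', {1: UidEndpoint(uid='ffn.0.1', endpoint='1.2.3.4:1337')})]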
+
+    async def _get_initial_beam(self, node: DHTNode, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
+                                num_workers: Optional[int] = None, future: Optional[MPFuture] = None
+                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        num_workers = num_workers or self.max_workers or beam_size
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
+        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
+
+        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
+            # dispatch additional tasks
+            while unattempted_indices and len(pending_tasks) < num_workers:
+                next_index = unattempted_indices.pop()  # note: this is the best unattempted index because of the sort order
+                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
+                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
+
+            # await the next best prefix to be fetched
+            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
+            try:
+                maybe_prefix_data = await pending_task
+                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
+                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
+                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
+                                  and len(match.value) == 2}
+                    beam.append((scores[pending_best_index], pending_best_prefix, successors))
+            except asyncio.CancelledError:
+                for _, _, pending_task in pending_tasks:  # each entry is a (coordinate, prefix, task) triple
+                    pending_task.cancel()
+                raise
+        if future:
+            future.set_result(beam)
+        return beam
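+
+    # Dispatch-order example (illustrative): with scores=(0.3, 0.9, 0.5) and prefix 'ffn.',
+    # unattempted_indices is sorted to [0, 2, 1] (worst to best), so .pop() queries 'ffn.1.'
+    # first, then 'ffn.2.', then 'ffn.0.', stopping once beam_size prefixes have been found.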
+
+    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
+                              num_workers: Optional[int] = None, return_future: bool = False
+                              ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        """
+        :param prefixes: a list of prefixes for which to find active successor uids
+        :param grid_size: if specified, only return successors if they are in range [0, grid_size)
+        :param num_workers: how many parallel workers to use for DHTNode.get_many
+        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
+        :returns: for every prefix, a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
+        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
+        """
+        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
+        for prefix in prefixes:
+            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_active_successors', [], dict(
+            prefixes=list(prefixes), grid_size=grid_size, num_workers=num_workers, future=_future)))
+        return future if return_future else future.result()
+
+    async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
+                                     num_workers: Optional[int] = None, future: Optional[MPFuture] = None
+                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        grid_size = grid_size or float('inf')
+        num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
+        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
+        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        for prefix, found in dht_responses.items():
+            if found and isinstance(found.value, dict):
+                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
+                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
+                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
+            else:
+                successors[prefix] = {}
+        if future:
+            future.set_result(successors)
+        return successors
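+
+    # Illustrative output (hypothetical values): get_active_successors(['ffn.1.']) might return
+    #   {'ffn.1.': {0: UidEndpoint(uid='ffn.1.0', endpoint='1.2.3.4:1337')}}
+    # and returns {'ffn.1.': {}} if no expert under that prefix is currently alive.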
+
+    def find_best_experts(self, prefix: ExpertPrefix, grid_scores: Sequence[Sequence[float]], beam_size: int,
+                          num_workers: Optional[int] = None, return_future: bool = False
+                          ) -> Union[List[RemoteExpert], MPFuture]:
        """
        Find and return :beam_size: active experts with the highest scores, using both the local cache and the DHT

@@ -203,174 +321,115 @@ class DHT(mp.Process):
        After time_budget is reached, beam search won't search for more experts and will instead fall back on the local cache
        Please note that any queries that fall outside the budget will still be performed in background and cached
        for subsequent iterations as long as DHTNode.cache_locally is True
+        :param num_workers: use up to this many concurrent workers to search DHT
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :param kwargs: extra keyword parameters passed to DHTNode.get_many
        :returns: a list that contains *up to* k_best RemoteExpert instances
        """
+        assert len(grid_scores) > 0 and beam_size > 0
+        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=list(map(tuple, grid_scores)),
-                                                       beam_size=beam_size, future=_future, **kwargs)))
+                                                       beam_size=beam_size, num_workers=num_workers, future=_future)))
        return future if return_future else future.result()

    async def _find_best_experts(
            self, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
-        max_workers: Optional[int] = max_workers or self.max_workers or beam_size
+            num_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
+        num_workers = num_workers or min(beam_size, self.max_workers or beam_size)

        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = await self._get_initial_beam(
-            node, prefix, beam_size, grid_scores[0], num_workers=min(beam_size, max_workers))
-        if not beam:
-            logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
-            return []
-        # TODO warn user if indices are out of range on the _last_ level! (rationale: beam search may return <k results)
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await self._get_initial_beam(
+            node, prefix, beam_size, grid_scores[0], min(beam_size, num_workers))
+        if not beam:
+            logger.warning("Beam search had to terminate prematurely because of empty beam (dim 0)")
+            if future is not None:
+                future.set_result([])
+            return []
+
+        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # min-heap that keeps the beam_size best (score, uid_endpoint) pairs
+        unique_experts: Set[ExpertUID] = set()

        for dim_index in range(1, len(grid_scores) - 1):
-            # select beam_size best suffixes from current beam
+            for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
+                if uid_endpoint.uid not in unique_experts:
+                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                    unique_experts.add(uid_endpoint.uid)
+
+            # form new beam using successors from the current beam
            dim_scores = grid_scores[dim_index]
-            best_active_pairs: List[Tuple[float, str]] = heapq.nlargest(beam_size, (
-                (prefix_score + dim_scores[int(suffix_i)], f"{prefix}{self.UID_DELIMITER}{suffix_i}")
-                for prefix_score, prefix, suffixes in beam for suffix_i in suffixes.keys()
-                # TODO get rid of str.isdecimal
-                if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)))
+            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
+                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
+                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
+                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
+            _, best_uid_prefixes = zip(*best_active_pairs)

            # search DHT for next step suffixes
-            _, best_uid_prefixes = zip(*best_active_pairs)
-            # TODO Tuple[Dict[str, List[str, Endpoint]], DHTExpiration] -> namedtuple
-            dht_responses: Dict[str, Tuple[Dict[str, List[str, Endpoint]], DHTExpiration]] = await node.get_many(
-                keys=best_uid_prefixes, num_workers=min(len(best_uid_prefixes), max_workers), **kwargs)
-            if all(expiration is None for key, (_, expiration) in dht_responses.items()):
-                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
+            successors = await self._get_active_successors(node, best_uid_prefixes, num_workers=num_workers)
+            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
+            if not beam:
+                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
                break
-            beam = [(prefix_score, prefix, dht_responses[prefix][0])  # add suffix dict if it is found
-                    for prefix_score, prefix in best_active_pairs if dht_responses[prefix][1] is not None]
-
-        # select best experts from the final beam
-        dim_scores = grid_scores[-1]
-        # TODO use heap to harness all results, get rid of five-line expression
-        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, chain((
-            (prefix_score + dim_scores[int(suffix_i)], uid, endpoint)
-            for prefix_score, prefix, suffixes in beam for suffix_i, ((uid, endpoint), _) in suffixes.items()
-            if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)
-        ), ((score, *suffixes['expert']) for score, _, suffixes in beam if 'expert' in suffixes)))
-        best_experts = [RemoteExpert(uid, endpoint) for score, uid, endpoint in final_best_pairs]
+
+        # add best experts from the final beam
+        for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
+            if uid_endpoint.uid not in unique_experts:
+                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                unique_experts.add(uid_endpoint.uid)
+
+        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
        if future is not None:
            future.set_result(best_experts)
        return best_experts
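+
+    # Minimal trace (hypothetical setup): prefix='ffn.', grid_scores=[[0.1, 0.9], [0.5, 0.2]],
+    # with active experts ffn.1.0 and ffn.0.1. The initial beam fetches 'ffn.1.' (score 0.9) and
+    # 'ffn.0.' (score 0.1); their exemplar experts re-score as 0.9 + 0.5 = 1.4 and 0.1 + 0.2 = 0.3,
+    # so the method returns [RemoteExpert('ffn.1.0', ...), RemoteExpert('ffn.0.1', ...)].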

-    def batch_find_best_experts(self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
-                                return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+    @staticmethod
+    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
+                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
+        """ iterate over all exemplar experts attached to current beam """
+        for score, prefix, suffixes in beam:
+            for next_coord, match in suffixes.items():
+                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
+                    yield score, match
+                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
+                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
+                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
+                           for i, coord in enumerate(expert_coords)):
+                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
+                        yield expert_score, match
+                    else:
+                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
+                else:
+                    logger.warning(f"Found incompatible expert UID: {match.uid}")
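+
+    # Note (illustrative): in the 1d case, declare_experts stores each uid under the FLAT_EXPERT (-1)
+    # coordinate, so a beam entry ('ffn.7.', {-1: UidEndpoint('ffn.7', E)}) yields the prefix score
+    # itself; multi-dimensional exemplars are re-scored by summing grid scores over all coordinates.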
+
+    def batch_find_best_experts(
+            self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
+            workers_per_sample: Optional[int] = None, return_future=False) -> Union[List[List[RemoteExpert]], MPFuture]:
        """
        Find and return :beam_size: active experts with the highest scores, using both the local cache and the DHT

        :param prefix: common prefix for all expert uids in grid
        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
-        :type batch_grid_scores: model scores for each example and each grid dimension, list of arrays of shape (batch_size, grid_size[i])
+        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
        :param beam_size: how many best experts should beam search return
        After time_budget is reached, beam search won't search for more experts and will instead fall back on the local cache
        Please note that any queries that fall outside the budget will still be performed in background and cached
        for subsequent iterations as long as DHTNode.cache_locally is True
+        :param workers_per_sample: use up to this many concurrent workers for every sample in batch
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :param kwargs: extra keyword parameters passed to DHTNode.get_many
        :returns: a list that contains *up to* k_best RemoteExpert instances for each sample in the batch
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_batch_find_best_experts', [], dict(prefix=prefix, batch_grid_scores=batch_grid_scores,
-                                                             beam_size=beam_size, future=_future, **kwargs)))
+                                                             beam_size=beam_size, workers_per_sample=workers_per_sample,
+                                                             future=_future)))
        return future if return_future else future.result()

    async def _batch_find_best_experts(
            self, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]], beam_size: int,
-            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[List[RemoteExpert]]:
+            workers_per_sample: Optional[int] = None, future: Optional[MPFuture] = None) -> List[List[RemoteExpert]]:

-        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))]
-        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, max_workers, **kwargs) for grid_scores in batch_grid_scores]
+        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
+                             for i in range(len(batch_grid_scores[0]))]
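+        # ^ transposes batch_grid_scores from [grid_dimension][sample] into one grid_scores list per
+        #   sample, so that each sample below runs its own concurrent beam search (descriptive comment)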
+        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, workers_per_sample)
+                 for grid_scores in batch_grid_scores]

        best_experts_batch = await asyncio.gather(*coros)
        if future is not None:
            future.set_result(best_experts_batch)
        return best_experts_batch
-
-    async def _get_initial_beam(self, node, prefix: str, beam_size: int, scores: Tuple[float, ...], num_workers: int
-                                ) -> List[Tuple[float, str, Dict[str, List[str]]]]:
-        """ Fetch a list of all active level-one prefixes of a given prefix. Used for beam search """
-        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = []  # results will be stored here
-        unattempted_indices: List[int] = sorted(range(len(scores)), key=scores.__getitem__)  # order: worst to best
-        pending_tasks: Deque[Tuple[int, str, asyncio.Task]] = deque()  # up to num_workers concurrent get tasks
-
-        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
-            # dispatch additional tasks
-            while unattempted_indices and len(pending_tasks) < num_workers:
-                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
-                next_best_prefix = f"{prefix}{self.UID_DELIMITER}{next_index}"
-                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
-
-            # await the next best prefix to be fetched
-            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
-            try:
-                maybe_prefix_data = await pending_task
-                if maybe_prefix_data is not None:
-                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data.value))
-            except asyncio.CancelledError:
-                for _, pending_task in pending_tasks:
-                    pending_task.cancel()
-                raise
-        return beam
-
-    def first_k_active(
-            self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
-            return_future=False) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[str, RemoteExpert]]]:
-        """
-        Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
-
-        :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
-        :param k: return at most *this many* active prefixes
-        :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
-        :param chunk_size: dispatch this many requests in one task
-        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
-        :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
-            The keys in the returned dict are ordered same as in uid_prefixes.
-        """
-        logger.warning("first_k_active is deprecated and will be removed in 0.8.8")
-        assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_first_k_active', [],
-                        dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
-                             chunk_size=chunk_size or k, future=_future)))
-        return future if return_future else future.result()
-
-    async def _first_k_active(
-            self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
-        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
-        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
-        found: List[Tuple[str, RemoteExpert]] = []
-
-        pending_tasks = deque(
-            asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
-                                              num_workers=num_workers_per_chunk))
-            for chunk_i in range(min(max_prefetch + 1, total_chunks))
-        )  # pre-dispatch first task and up to max_prefetch additional tasks
-
-        for chunk_i in range(total_chunks):
-            # parse task results in chronological order, launch additional tasks on demand
-            response = await pending_tasks.popleft()
-            for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
-                if response[uid_prefix] is not None and len(response[uid_prefix].value) > 0:  # found active peer
-                    found.append((uid_prefix, RemoteExpert(*next(iter(response[uid_prefix].value.values()))[0])))
-                # if we found enough active experts, finish immediately
-                if len(found) >= k:
-                    break
-            if len(found) >= k:
-                break
-
-            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
-            if pre_dispatch_chunk_i < total_chunks:
-                pending_tasks.append(asyncio.create_task(node.get_many(
-                    uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
-                    num_workers=num_workers_per_chunk)))
-
-        for task in pending_tasks:
-            task.cancel()
-
-        # return k active prefixes or as many as we could find
-        future.set_result(OrderedDict(found))