@@ -5,25 +5,21 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import (
-    RemoteExpert,
-    RemoteExpertInfo,
-    batch_create_remote_experts,
-    create_remote_experts,
-)
-from hivemind.moe.server.expert_uid import (
+from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
     UID_DELIMITER,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     Score,
-    UidEndpoint,
     is_valid_prefix,
+    is_valid_uid,
 )
-from hivemind.p2p import PeerInfo
-from hivemind.utils import MPFuture, get_dht_time, get_logger
+from hivemind.p2p import PeerID
+from hivemind.utils import MPFuture, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -100,7 +96,7 @@ class MoEBeamSearcher:
 
     def get_initial_beam(
         self, scores: Sequence[float], beam_size: int, return_future: bool = False
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         """
         :param scores: prefer suffix coordinates that have highest scores
         :param beam_size: select this many active suffixes with highest scores
@@ -130,9 +126,9 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         num_workers = num_workers or dht.num_workers or beam_size
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
         )  # from worst to best
@@ -150,13 +146,7 @@ class MoEBeamSearcher:
             try:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {
-                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
-                        for coord, match in maybe_prefix_data.value.items()
-                        if isinstance(coord, Coordinate)
-                        and isinstance(getattr(match, "value", None), list)
-                        and len(match.value) == 2
-                    }
+                    successors = MoEBeamSearcher._select_valid_entries(maybe_prefix_data)
                     if successors:
                         beam.append((scores[pending_best_index], pending_best_prefix, successors))
                 elif maybe_prefix_data is None and negative_caching:
@@ -178,7 +168,7 @@ class MoEBeamSearcher:
 
     def get_active_successors(
         self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         """
         :param prefixes: a list of prefixes for which to find active successor uids
         :param grid_size: if specified, only return successors if they are in range [0, grid_size)
@@ -201,6 +191,22 @@ class MoEBeamSearcher:
             return_future=return_future,
         )
 
+    @staticmethod
+    def _select_valid_entries(entry: ValueWithExpiration, grid_size: Optional[int] = None):
+        if not isinstance(entry, ValueWithExpiration) or not isinstance(entry.value, dict):
+            return {}
+        return {
+            coord: ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
+            for coord, match in entry.value.items()
+            if isinstance(coord, Coordinate)
+            and (grid_size is None or 0 <= coord < grid_size)
+            and isinstance(match, ValueWithExpiration)
+            and isinstance(match.value, tuple)
+            and len(match.value) == 2
+            and is_valid_uid(match.value[0])
+            and isinstance(match.value[1], str)
+        }
+
     @staticmethod
     async def _get_active_successors(
         dht: DHT,
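
For orientation, a minimal sketch of the DHT value shape that the new _select_valid_entries helper filters. It assumes the patched module is importable as hivemind.moe.client.beam_search; the coordinates, UIDs and the base58 peer id below are made-up placeholders rather than anything a real server would publish:

    # Illustrative only, not part of the patch: a prefix entry as _select_valid_entries expects it.
    from hivemind.moe.client.beam_search import MoEBeamSearcher
    from hivemind.utils import ValueWithExpiration, get_dht_time

    expires = get_dht_time() + 60
    prefix_entry = ValueWithExpiration(
        value={
            3: ValueWithExpiration(("expert.0.3", "QmTestPeer"), expires),  # kept: int coord, (uid, base58 str) pair
            "meta": ValueWithExpiration(("expert.0.4", "QmTestPeer"), expires),  # dropped: coordinate is not an int
            -1: ValueWithExpiration(None, expires),  # dropped: e.g. a negative-cache tombstone (value is None)
        },
        expiration_time=expires,
    )
    successors = MoEBeamSearcher._select_valid_entries(prefix_entry, grid_size=8)
    assert list(successors) == [3]  # only the well-formed, in-grid entry becomes an ExpertInfo

The same filter also swallows the subkey=-1 negative-cache tombstones (value None) that are stored further down in this patch.
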
@@ -210,28 +216,18 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         grid_size = grid_size or float("inf")
         num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
-        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        successors: Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]] = {}
         for prefix, found in dht_responses.items():
-            if found and isinstance(found.value, dict):
-                successors[prefix] = {
-                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
-                    for coord, match in found.value.items()
-                    if isinstance(coord, Coordinate)
-                    and 0 <= coord < grid_size
-                    and isinstance(getattr(match, "value", None), list)
-                    and len(match.value) == 2
-                }
-            else:
-                successors[prefix] = {}
-                if found is None and negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(
-                        node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
-                    )
+            successors[prefix] = MoEBeamSearcher._select_valid_entries(found, grid_size)
+            if not successors[prefix] and negative_caching:
+                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                asyncio.create_task(
+                    node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
+                )
         return successors
 
     def find_best_experts(
@@ -246,7 +242,6 @@ class MoEBeamSearcher:
         After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
         Please note that any queries that fall outside the budget will still be performed in background and cached
         for subsequent iterations as long as DHTNode.cache_locally is True
-        :param num_workers: use up to this many concurrent workers to search DHT
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
@@ -263,7 +258,6 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
-
         return create_remote_experts(result, self.dht, return_future)
 
     @classmethod
@@ -277,23 +271,23 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[RemoteExpertInfo]:
+    ) -> List[ExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = await cls._get_initial_beam(
             dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
         )
 
-        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        best_experts_heap: List[Tuple[Score, ExpertInfo]] = []  # max-heap of expert infos ordered by scores
         unique_experts: Set[ExpertUID] = set()
 
         for dim_index in range(1, len(grid_scores) - 1):
-            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-                if uid_endpoint.uid not in unique_experts:
+            for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+                if expert_info.uid not in unique_experts:
                     push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                    unique_experts.add(uid_endpoint.uid)
+                    push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                    unique_experts.add(expert_info.uid)
 
             # form new beam using successors from the current beam
             dim_scores = grid_scores[dim_index]
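
The push_and_maybe_pop idiom above keeps only the beam_size highest-scoring experts: heapq maintains a min-heap, so once the heap holds beam_size items, heappushpop inserts the new pair and immediately evicts the lowest score. A self-contained toy run of the same idiom, with made-up scores and strings standing in for ExpertInfo objects:

    import heapq

    beam_size = 3
    best_experts_heap = []
    for score, expert in [(0.1, "e1"), (0.9, "e2"), (0.4, "e3"), (0.7, "e4"), (0.2, "e5")]:
        push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
        push_and_maybe_pop(best_experts_heap, (score, expert))

    # sorted(..., reverse=True) mirrors the final return in _find_best_experts: best score first
    assert sorted(best_experts_heap, reverse=True) == [(0.9, "e2"), (0.7, "e4"), (0.4, "e3")]
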
@@ -306,6 +300,7 @@ class MoEBeamSearcher:
                     if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
                 ),
             )
+
             _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
@@ -324,22 +319,18 @@ class MoEBeamSearcher:
                 break
 
         # add best experts from the final beam
-        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-            if uid_endpoint.uid not in unique_experts:
+        for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+            if expert_info.uid not in unique_experts:
                 push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                unique_experts.add(uid_endpoint.uid)
+                push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                unique_experts.add(expert_info.uid)
 
-        best_experts = [
-            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
-            for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
-        ]
-        return best_experts
+        return [expert_info for _, expert_info in sorted(best_experts_heap, reverse=True)]
 
     @staticmethod
     def _iterate_matching_experts(
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
-    ) -> Iterator[Tuple[Score, UidEndpoint]]:
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]], grid_scores: Sequence[Sequence[float]]
+    ) -> Iterator[Tuple[Score, ExpertInfo]]:
         """iterate over all exemplar experts attached to current beam"""
         for score, prefix, suffixes in beam:
             for next_coord, match in suffixes.items():
@@ -399,7 +390,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpertInfo]]:
+    ) -> Sequence[Sequence[ExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]
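
The comprehension above regroups batch_grid_scores from one array per grid dimension into one list of per-sample score tuples, so that each sample can run its own beam search. A worked toy example of that reshape (two samples, two grid dimensions, made-up numbers):

    # batch_grid_scores[dim][sample] -> scores over that dimension's coordinates
    batch_grid_scores = [
        [[0.1, 0.9], [0.3, 0.7]],  # dimension 0: scores for sample 0 and sample 1
        [[0.5, 0.5], [0.8, 0.2]],  # dimension 1: scores for sample 0 and sample 1
    ]
    per_sample = [
        [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
    ]
    assert per_sample == [
        [(0.1, 0.9), (0.5, 0.5)],  # sample 0: one tuple of scores per grid dimension
        [(0.3, 0.7), (0.8, 0.2)],  # sample 1
    ]
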