|
@@ -4,7 +4,7 @@ from collections import deque
|
|
|
from functools import partial
|
|
|
from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
|
|
|
|
|
|
-from hivemind.dht import DHT, DHTNode
|
|
|
+from hivemind.dht import DHT, DHTNode, DHTExpiration
|
|
|
from hivemind.client.expert import RemoteExpert
|
|
|
from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
|
|
|
PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
|
|
@@ -22,7 +22,7 @@ class MoEBeamSearcher:
|
|
|
* optional prefix that determines expert role, experiment name, etc.
|
|
|
* one or more integers that determine that expert's position in an N-dimensional grid
|
|
|
|
|
|
- A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
|
|
|
+ A hivemind.Server can call ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
|
|
|
When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
|
|
|
For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
|
|
|
``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
|
|
@@ -63,8 +63,8 @@ class MoEBeamSearcher:
|
|
|
Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Tuple[int, ...],
|
|
|
- num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
|
|
|
+ def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Sequence[int], num_workers: Optional[int] = None,
|
|
|
+ negative_caching: bool = True, cache_expiration: DHTExpiration = 300, **kwargs):
|
|
|
if not uid_prefix.endswith(UID_DELIMITER):
|
|
|
uid_prefix += UID_DELIMITER
|
|
|
logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
|
|
@@ -72,7 +72,8 @@ class MoEBeamSearcher:
|
|
|
self.dht = dht
|
|
|
self.uid_prefix, self.grid_size = uid_prefix, grid_size
|
|
|
self.total_grid_size = sum(grid_size)
|
|
|
- self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
|
|
|
+ self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
|
|
|
+ self.num_workers, self.dht_kwargs = num_workers, kwargs
|
|
|
|
|
|
def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
|
|
|
) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
|
|
@@ -84,12 +85,14 @@ class MoEBeamSearcher:
|
|
|
"""
|
|
|
return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
|
|
|
scores=tuple(scores), negative_caching=self.negative_caching,
|
|
|
- num_workers=self.num_workers), return_future)
|
|
|
+ cache_expiration=self.cache_expiration, num_workers=self.num_workers),
|
|
|
+ return_future)
|
|
|
|
|
|
@staticmethod
|
|
|
- async def _get_initial_beam(dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int,
|
|
|
- scores: Tuple[float, ...], negative_caching: bool, num_workers: Optional[int] = None
|
|
|
- ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
|
|
|
+ async def _get_initial_beam(
|
|
|
+ dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
|
|
|
+ negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None,
|
|
|
+ ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
|
|
|
num_workers = num_workers or dht.max_workers or beam_size
|
|
|
beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
|
|
|
unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__) # from worst to best
|
|
@@ -115,7 +118,7 @@ class MoEBeamSearcher:
|
|
|
elif maybe_prefix_data is None and negative_caching:
|
|
|
logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
|
|
|
asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
|
|
|
- expiration_time=get_dht_time() + dht.default_expiration))
|
|
|
+ expiration_time=get_dht_time() + cache_expiration))
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
for _, pending_task in pending_tasks:
|
|
@@ -137,12 +140,14 @@ class MoEBeamSearcher:
|
|
|
assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
|
|
|
return self.dht.run_coroutine(partial(
|
|
|
self._get_active_successors, prefixes=list(prefixes), grid_size=grid_size,
|
|
|
- negative_caching=self.negative_caching, num_workers=self.num_workers), return_future=return_future)
|
|
|
+ negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
|
|
|
+ num_workers=self.num_workers), return_future=return_future)
|
|
|
|
|
|
@staticmethod
|
|
|
- async def _get_active_successors(dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
|
|
|
- negative_caching: bool, num_workers: Optional[int] = None
|
|
|
- ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
|
|
|
+ async def _get_active_successors(
|
|
|
+ dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
|
|
|
+ negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
|
|
|
+ ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
|
|
|
grid_size = grid_size or float('inf')
|
|
|
num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
|
|
|
dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
|
|
@@ -157,7 +162,7 @@ class MoEBeamSearcher:
|
|
|
if found is None and negative_caching:
|
|
|
logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
|
|
|
asyncio.create_task(node.store(prefix, subkey=-1, value=None,
|
|
|
- expiration_time=get_dht_time() + dht.default_expiration))
|
|
|
+ expiration_time=get_dht_time() + cache_expiration))
|
|
|
return successors
|
|
|
|
|
|
def find_best_experts(self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
|
|
@@ -176,14 +181,16 @@ class MoEBeamSearcher:
|
|
|
:returns: a list that contains *up to* k_best RemoteExpert instances
|
|
|
"""
|
|
|
assert len(grid_scores) == len(self.grid_size) and beam_size > 0
|
|
|
- return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
|
|
|
- grid_scores=list(grid_scores), negative_caching=self.negative_caching,
|
|
|
- num_workers=self.num_workers), return_future)
|
|
|
+ return self.dht.run_coroutine(partial(
|
|
|
+ self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size, grid_scores=list(grid_scores),
|
|
|
+ negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
|
|
|
+ num_workers=self.num_workers), return_future)
|
|
|
|
|
|
@classmethod
|
|
|
async def _find_best_experts(
|
|
|
cls, dht: DHT, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
|
|
|
- negative_caching: bool, num_workers: Optional[int] = None) -> List[RemoteExpert]:
|
|
|
+ negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
|
|
|
+ ) -> List[RemoteExpert]:
|
|
|
num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
|
|
|
|
|
|
# form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
|
|
@@ -209,8 +216,9 @@ class MoEBeamSearcher:
|
|
|
_, best_uid_prefixes = zip(*best_active_pairs)
|
|
|
|
|
|
# search DHT for next step suffixes
|
|
|
- successors = await cls._get_active_successors(dht, node, best_uid_prefixes, grid_size=None,
|
|
|
- negative_caching=negative_caching, num_workers=num_workers)
|
|
|
+ successors = await cls._get_active_successors(
|
|
|
+ dht, node, best_uid_prefixes, grid_size=None, negative_caching=negative_caching,
|
|
|
+ cache_expiration=cache_expiration, num_workers=num_workers)
|
|
|
beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
|
|
|
if not beam:
|
|
|
logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
|