# beam_search.py (20 KB)
import asyncio
import heapq
from collections import deque
from functools import partial
from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

from hivemind.dht import DHT, DHTNode
from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
from hivemind.moe.expert_uid import (
    FLAT_EXPERT,
    PREFIX_PATTERN,
    UID_DELIMITER,
    Coordinate,
    ExpertInfo,
    ExpertPrefix,
    ExpertUID,
    Score,
    is_valid_prefix,
    is_valid_uid,
)
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_dht_time, get_logger
# Module-level logger; MoEBeamSearcher uses it for info / debug / warning messages during beam search.
logger = get_logger(__name__)
  23. class MoEBeamSearcher:
  24. """
  25. Utility class that uses DHT to find most suitable experts for RemoteMixtureOfExperts.
  26. Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
  27. An expert identifier consists of:
  28. * optional prefix that determines expert role, experiment name, etc.
  29. * one or more integers that determine that expert's position in an N-dimensional grid
  30. A hivemind.moe.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
  31. When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
  32. For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
  33. ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
  34. In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
  35. (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....}))
  36. RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
  37. For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
  38. This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
  39. However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
  40. In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
  41. It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
  42. After selecting k best indices along first dimension, MoE moves to the second dimension.
  43. It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
  44. This beam search explores one additional dimension per step and finds k best experts from across the DHT
  45. in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
  46. :param dht: a running DHT daemon that is used for beam search AND local caching
  47. :param uid_prefix: search for experts whose uids start with this prefix
  48. :param grid_size: dimensions that form expert uid (see above)
  49. :param num_workers: number of concurrent DHT coroutines per beam search
  50. :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
  51. result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
  52. 1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
  53. a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
  54. result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
  55. 2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
  56. search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
  57. experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
  58. visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
  59. 3. Faster beam search in very sparse grids: there is one edge case where negative cache will improve beam search
  60. performance; If an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
  61. indices {i} such that _no_ experts that start with "{prefix}.{i}.*"). If so, the default beam search will
  62. be very slow due to the way it forms initial beam. Beam search with negative cache enabled will run normally.
  63. Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
  64. """
  65. def __init__(
  66. self,
  67. dht: DHT,
  68. uid_prefix: ExpertPrefix,
  69. grid_size: Sequence[int],
  70. num_workers: Optional[int] = None,
  71. negative_caching: bool = True,
  72. cache_expiration: DHTExpiration = 300,
  73. **kwargs,
  74. ):
  75. if not uid_prefix.endswith(UID_DELIMITER):
  76. uid_prefix += UID_DELIMITER
  77. logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
  78. assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
  79. self.dht = dht
  80. self.uid_prefix, self.grid_size = uid_prefix, grid_size
  81. self.total_grid_size = sum(grid_size)
  82. self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
  83. self.num_workers, self.dht_kwargs = num_workers, kwargs
  84. def get_initial_beam(
  85. self, scores: Sequence[float], beam_size: int, return_future: bool = False
  86. ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
  87. """
  88. :param scores: prefer suffix coordinates that have highest scores
  89. :param beam_size: select this many active suffixes with highest scores
  90. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  91. :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
  92. """
  93. return self.dht.run_coroutine(
  94. partial(
  95. self._get_initial_beam,
  96. prefix=self.uid_prefix,
  97. beam_size=beam_size,
  98. scores=tuple(scores),
  99. negative_caching=self.negative_caching,
  100. cache_expiration=self.cache_expiration,
  101. num_workers=self.num_workers,
  102. ),
  103. return_future,
  104. )
  105. @staticmethod
  106. async def _get_initial_beam(
  107. dht: DHT,
  108. node: DHTNode,
  109. prefix: ExpertPrefix,
  110. beam_size: int,
  111. scores: Tuple[float, ...],
  112. negative_caching: bool,
  113. cache_expiration: DHTExpiration,
  114. num_workers: Optional[int] = None,
  115. ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
  116. num_workers = num_workers or dht.num_workers or beam_size
  117. beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = []
  118. unattempted_indices: List[Coordinate] = sorted(
  119. range(len(scores)), key=scores.__getitem__
  120. ) # from worst to best
  121. pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
  122. while len(beam) < beam_size and (unattempted_indices or pending_tasks):
  123. # dispatch additional tasks
  124. while unattempted_indices and len(pending_tasks) < num_workers:
  125. next_index = unattempted_indices.pop() # note: this is best unattempted index because of sort order
  126. next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
  127. pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
  128. # await the next best prefix to be fetched
  129. pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
  130. try:
  131. maybe_prefix_data = await pending_task
  132. if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
  133. successors = MoEBeamSearcher._select_valid_entries(maybe_prefix_data)
  134. if successors:
  135. beam.append((scores[pending_best_index], pending_best_prefix, successors))
  136. elif maybe_prefix_data is None and negative_caching:
  137. logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
  138. asyncio.create_task(
  139. node.store(
  140. pending_best_prefix,
  141. subkey=-1,
  142. value=None,
  143. expiration_time=get_dht_time() + cache_expiration,
  144. )
  145. )
  146. except asyncio.CancelledError:
  147. for _, pending_task in pending_tasks:
  148. pending_task.cancel()
  149. raise
  150. return beam
  151. def get_active_successors(
  152. self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
  153. ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
  154. """
  155. :param prefixes: a list of prefix for which to find active successor uids
  156. :param grid_size: if specified, only return successors if they are in range [0, grid_size)
  157. :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
  158. :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
  159. :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
  160. """
  161. assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
  162. for prefix in prefixes:
  163. assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
  164. return self.dht.run_coroutine(
  165. partial(
  166. self._get_active_successors,
  167. prefixes=list(prefixes),
  168. grid_size=grid_size,
  169. negative_caching=self.negative_caching,
  170. cache_expiration=self.cache_expiration,
  171. num_workers=self.num_workers,
  172. ),
  173. return_future=return_future,
  174. )
  175. @staticmethod
  176. def _select_valid_entries(entry: ValueWithExpiration, grid_size: Optional[int] = None):
  177. if not isinstance(entry, ValueWithExpiration) or not isinstance(entry.value, dict):
  178. return {}
  179. return {
  180. coord: ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
  181. for coord, match in entry.value.items()
  182. if isinstance(coord, Coordinate)
  183. and (grid_size is None or 0 <= coord < grid_size)
  184. and isinstance(match, ValueWithExpiration)
  185. and isinstance(match.value, tuple)
  186. and len(match.value) == 2
  187. and is_valid_uid(match.value[0])
  188. and isinstance(match.value[1], str)
  189. }
  190. @staticmethod
  191. async def _get_active_successors(
  192. dht: DHT,
  193. node: DHTNode,
  194. prefixes: List[ExpertPrefix],
  195. grid_size: Optional[int],
  196. negative_caching: bool,
  197. cache_expiration: DHTExpiration,
  198. num_workers: Optional[int] = None,
  199. ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
  200. grid_size = grid_size or float("inf")
  201. num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
  202. dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
  203. successors: Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]] = {}
  204. for prefix, found in dht_responses.items():
  205. successors[prefix] = MoEBeamSearcher._select_valid_entries(found, grid_size)
  206. if not successors[prefix] and negative_caching:
  207. logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
  208. asyncio.create_task(
  209. node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
  210. )
  211. return successors
  212. def find_best_experts(
  213. self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
  214. ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
  215. """
  216. Find and return :beam_size: active experts with highest scores, use both local cache and DHT
  217. :param grid_scores: scores predicted for each dimension in the grid
  218. :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
  219. :param beam_size: how many best experts should beam search return
  220. After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
  221. Please note that any queries that fall outside the budget will still be performed in background and cached
  222. for subsequent iterations as long as DHTNode.cache_locally is True
  223. :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
  224. :returns: a list that contains *up to* k_best RemoteExpert instances
  225. """
  226. assert len(grid_scores) == len(self.grid_size) and beam_size > 0
  227. result = self.dht.run_coroutine(
  228. partial(
  229. self._find_best_experts,
  230. prefix=self.uid_prefix,
  231. beam_size=beam_size,
  232. grid_scores=list(grid_scores),
  233. negative_caching=self.negative_caching,
  234. cache_expiration=self.cache_expiration,
  235. num_workers=self.num_workers,
  236. ),
  237. return_future,
  238. )
  239. return create_remote_experts(result, self.dht, return_future)
  240. @classmethod
  241. async def _find_best_experts(
  242. cls,
  243. dht: DHT,
  244. node: DHTNode,
  245. prefix: str,
  246. grid_scores: List[Tuple[float]],
  247. beam_size: int,
  248. negative_caching: bool,
  249. cache_expiration: DHTExpiration,
  250. num_workers: Optional[int] = None,
  251. ) -> List[ExpertInfo]:
  252. num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
  253. # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
  254. beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = await cls._get_initial_beam(
  255. dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
  256. )
  257. best_experts_heap: List[Tuple[Score, ExpertInfo]] = [] # max-heap of expert infos ordered by scores
  258. unique_experts: Set[ExpertUID] = set()
  259. for dim_index in range(1, len(grid_scores) - 1):
  260. for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
  261. if expert_info.uid not in unique_experts:
  262. push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
  263. push_and_maybe_pop(best_experts_heap, (score, expert_info))
  264. unique_experts.add(expert_info.uid)
  265. # form new beam using successors from the current beam
  266. dim_scores = grid_scores[dim_index]
  267. best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(
  268. beam_size,
  269. (
  270. (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
  271. for prefix_score, prefix, suffixes in beam
  272. for next_coord in suffixes.keys()
  273. if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
  274. ),
  275. )
  276. _, best_uid_prefixes = zip(*best_active_pairs)
  277. # search DHT for next step suffixes
  278. successors = await cls._get_active_successors(
  279. dht,
  280. node,
  281. best_uid_prefixes,
  282. grid_size=None,
  283. negative_caching=negative_caching,
  284. cache_expiration=cache_expiration,
  285. num_workers=num_workers,
  286. )
  287. beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
  288. if not beam:
  289. logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
  290. break
  291. # add best experts from the final beam
  292. for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
  293. if expert_info.uid not in unique_experts:
  294. push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
  295. push_and_maybe_pop(best_experts_heap, (score, expert_info))
  296. unique_experts.add(expert_info.uid)
  297. return [expert_info for _, expert_info in sorted(best_experts_heap, reverse=True)]
  298. @staticmethod
  299. def _iterate_matching_experts(
  300. beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]], grid_scores: Sequence[Sequence[float]]
  301. ) -> Iterator[Tuple[Score, ExpertInfo]]:
  302. """iterate over all exemplar experts attached to current beam"""
  303. for score, prefix, suffixes in beam:
  304. for next_coord, match in suffixes.items():
  305. if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
  306. yield score, match
  307. elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
  308. expert_coords = match.uid.split(UID_DELIMITER)[1:]
  309. if all(
  310. coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
  311. for i, coord in enumerate(expert_coords)
  312. ):
  313. expert_score = sum(
  314. scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords))
  315. )
  316. yield expert_score, match
  317. else:
  318. logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
  319. else:
  320. logger.warning(f"Found incompatible expert UID: {match.uid}")
  321. def batch_find_best_experts(
  322. self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
  323. ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
  324. """
  325. Find and return :beam_size: active experts with highest scores, use both local cache and DHT
  326. :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
  327. :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
  328. :param beam_size: how many best experts should beam search return
  329. After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
  330. Please note that any queries that fall outside the budget will still be performed in background and cached
  331. for subsequent iterations as long as DHTNode.cache_locally is True
  332. :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
  333. :returns: a list that contains *up to* k_best RemoteExpert instances
  334. """
  335. result = self.dht.run_coroutine(
  336. partial(
  337. self._batch_find_best_experts,
  338. prefix=self.uid_prefix,
  339. batch_grid_scores=batch_grid_scores,
  340. beam_size=beam_size,
  341. negative_caching=self.negative_caching,
  342. num_workers=self.num_workers,
  343. ),
  344. return_future,
  345. )
  346. return batch_create_remote_experts(result, self.dht, return_future)
  347. @classmethod
  348. async def _batch_find_best_experts(
  349. cls,
  350. dht: DHT,
  351. node: DHTNode,
  352. prefix: str,
  353. batch_grid_scores: Sequence[Sequence[Tuple[float]]],
  354. beam_size: int,
  355. negative_caching: bool,
  356. num_workers: Optional[int],
  357. ) -> Sequence[Sequence[ExpertInfo]]:
  358. batch_grid_scores = [
  359. [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
  360. ]
  361. coros = [
  362. cls._find_best_experts(dht, node, prefix, grid_scores, beam_size, negative_caching, num_workers)
  363. for grid_scores in batch_grid_scores
  364. ]
  365. return await asyncio.gather(*coros)