beam_search.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import asyncio
  2. import heapq
  3. from collections import deque
  4. from functools import partial
  5. from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
  6. from hivemind.dht import DHT, DHTExpiration, DHTNode
  7. from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
  8. from hivemind.moe.server.expert_uid import (
  9. FLAT_EXPERT,
  10. PREFIX_PATTERN,
  11. UID_DELIMITER,
  12. Coordinate,
  13. ExpertPrefix,
  14. ExpertUID,
  15. Score,
  16. UidEndpoint,
  17. is_valid_prefix,
  18. )
  19. from hivemind.p2p import PeerInfo
  20. from hivemind.utils import get_dht_time, get_logger, LazyFutureCaller, LazyValue
  21. logger = get_logger(__name__)
  22. class MoEBeamSearcher:
  23. """
  24. Utility class that uses DHT to find most suitable experts for RemoteMixtureOfExperts.
  25. Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
  26. An expert identifier consists of:
  27. * optional prefix that determines expert role, experiment name, etc.
  28. * one or more integers that determine that expert's position in an N-dimensional grid
  29. A hivemind.moe.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
  30. When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
  31. For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
  32. ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
  33. In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
  34. (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....}))
  35. RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
  36. For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
  37. This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
  38. However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
  39. In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
  40. It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
  41. After selecting k best indices along first dimension, MoE moves to the second dimension.
  42. It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
  43. This beam search explores one additional dimension per step and finds k best experts from across the DHT
  44. in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
  45. :param dht: a running DHT daemon that is used for beam search AND local caching
  46. :param uid_prefix: search for experts whose uids start with this prefix
  47. :param grid_size: dimensions that form expert uid (see above)
  48. :param num_workers: number of concurrent DHT coroutines per beam search
  49. :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
  50. result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
  51. 1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
  52. a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
  53. result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
  54. 2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
  55. search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
  56. experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
  57. visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
  58. 3. Faster beam search in very sparse grids: there is one edge case where negative cache will improve beam search
  59. performance; If an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
  60. indices {i} such that _no_ experts that start with "{prefix}.{i}.*"). If so, the default beam search will
  61. be very slow due to the way it forms initial beam. Beam search with negative cache enabled will run normally.
  62. Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
  63. """
  64. def __init__(
  65. self,
  66. dht: DHT,
  67. uid_prefix: ExpertPrefix,
  68. grid_size: Sequence[int],
  69. num_workers: Optional[int] = None,
  70. negative_caching: bool = True,
  71. cache_expiration: DHTExpiration = 300,
  72. **kwargs,
  73. ):
  74. if not uid_prefix.endswith(UID_DELIMITER):
  75. uid_prefix += UID_DELIMITER
  76. logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
  77. assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
  78. self.dht = dht
  79. self.uid_prefix, self.grid_size = uid_prefix, grid_size
  80. self.total_grid_size = sum(grid_size)
  81. self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
  82. self.num_workers, self.dht_kwargs = num_workers, kwargs
  83. def get_initial_beam(
  84. self, scores: Sequence[float], beam_size: int, return_future: bool = False
  85. ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
  86. """
  87. :param scores: prefer suffix coordinates that have highest scores
  88. :param beam_size: select this many active suffixes with highest scores
  89. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  90. :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
  91. """
  92. return self.dht.run_coroutine(
  93. partial(
  94. self._get_initial_beam,
  95. prefix=self.uid_prefix,
  96. beam_size=beam_size,
  97. scores=tuple(scores),
  98. negative_caching=self.negative_caching,
  99. cache_expiration=self.cache_expiration,
  100. num_workers=self.num_workers,
  101. ),
  102. return_future,
  103. )
  104. @staticmethod
  105. async def _get_initial_beam(
  106. dht: DHT,
  107. node: DHTNode,
  108. prefix: ExpertPrefix,
  109. beam_size: int,
  110. scores: Tuple[float, ...],
  111. negative_caching: bool,
  112. cache_expiration: DHTExpiration,
  113. num_workers: Optional[int] = None,
  114. ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
  115. num_workers = num_workers or dht.num_workers or beam_size
  116. beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
  117. unattempted_indices: List[Coordinate] = sorted(
  118. range(len(scores)), key=scores.__getitem__
  119. ) # from worst to best
  120. pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
  121. while len(beam) < beam_size and (unattempted_indices or pending_tasks):
  122. # dispatch additional tasks
  123. while unattempted_indices and len(pending_tasks) < num_workers:
  124. next_index = unattempted_indices.pop() # note: this is best unattempted index because of sort order
  125. next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
  126. pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
  127. # await the next best prefix to be fetched
  128. pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
  129. try:
  130. maybe_prefix_data = await pending_task
  131. if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
  132. successors = {
  133. coord: UidEndpoint(*match.value)
  134. for coord, match in maybe_prefix_data.value.items()
  135. if isinstance(coord, Coordinate)
  136. and isinstance(getattr(match, "value", None), list)
  137. and len(match.value) == 2
  138. }
  139. if successors:
  140. beam.append((scores[pending_best_index], pending_best_prefix, successors))
  141. elif maybe_prefix_data is None and negative_caching:
  142. logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
  143. asyncio.create_task(
  144. node.store(
  145. pending_best_prefix,
  146. subkey=-1,
  147. value=None,
  148. expiration_time=get_dht_time() + cache_expiration,
  149. )
  150. )
  151. except asyncio.CancelledError:
  152. for _, pending_task in pending_tasks:
  153. pending_task.cancel()
  154. raise
  155. return beam
  156. def get_active_successors(
  157. self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
  158. ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
  159. """
  160. :param prefixes: a list of prefix for which to find active successor uids
  161. :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
  162. :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
  163. :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
  164. :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
  165. """
  166. assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
  167. for prefix in prefixes:
  168. assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
  169. return self.dht.run_coroutine(
  170. partial(
  171. self._get_active_successors,
  172. prefixes=list(prefixes),
  173. grid_size=grid_size,
  174. negative_caching=self.negative_caching,
  175. cache_expiration=self.cache_expiration,
  176. num_workers=self.num_workers,
  177. ),
  178. return_future=return_future,
  179. )
  180. @staticmethod
  181. async def _get_active_successors(
  182. dht: DHT,
  183. node: DHTNode,
  184. prefixes: List[ExpertPrefix],
  185. grid_size: Optional[int],
  186. negative_caching: bool,
  187. cache_expiration: DHTExpiration,
  188. num_workers: Optional[int] = None,
  189. ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
  190. grid_size = grid_size or float("inf")
  191. num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
  192. dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
  193. successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
  194. for prefix, found in dht_responses.items():
  195. if found and isinstance(found.value, dict):
  196. successors[prefix] = {
  197. coord: UidEndpoint(*match.value)
  198. for coord, match in found.value.items()
  199. if isinstance(coord, Coordinate)
  200. and 0 <= coord < grid_size
  201. and isinstance(getattr(match, "value", None), list)
  202. and len(match.value) == 2
  203. }
  204. else:
  205. successors[prefix] = {}
  206. if found is None and negative_caching:
  207. logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
  208. asyncio.create_task(
  209. node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
  210. )
  211. return successors
  212. def find_best_experts(
  213. self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
  214. ) -> Union[List[RemoteExpert], LazyFutureCaller]:
  215. """
  216. Find and return :beam_size: active experts with highest scores, use both local cache and DHT
  217. :param grid_scores: scores predicted for each dimension in the grid
  218. :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
  219. :param beam_size: how many best experts should beam search return
  220. After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
  221. Please note that any queries that fall outside the budget will still be performed in background and cached
  222. for subsequent iterations as long as DHTNode.cache_locally is True
  223. :param num_workers: use up to this many concurrent workers to search DHT
  224. :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
  225. :returns: a list that contains *up to* k_best RemoteExpert instances
  226. """
  227. assert len(grid_scores) == len(self.grid_size) and beam_size > 0
  228. result = self.dht.run_coroutine(
  229. partial(
  230. self._find_best_experts,
  231. prefix=self.uid_prefix,
  232. beam_size=beam_size,
  233. grid_scores=list(grid_scores),
  234. negative_caching=self.negative_caching,
  235. cache_expiration=self.cache_expiration,
  236. num_workers=self.num_workers,
  237. ),
  238. return_future,
  239. )
  240. p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
  241. if return_future:
  242. return LazyFutureCaller(
  243. result,
  244. lambda lst: [l.get(p2p=p2p) for l in lst]
  245. )
  246. return [r.get(p2p=p2p) for r in result]
  247. @classmethod
  248. async def _find_best_experts(
  249. cls,
  250. dht: DHT,
  251. node: DHTNode,
  252. prefix: str,
  253. grid_scores: List[Tuple[float]],
  254. beam_size: int,
  255. negative_caching: bool,
  256. cache_expiration: DHTExpiration,
  257. num_workers: Optional[int] = None,
  258. ) -> List[LazyValue[RemoteExpert]]:
  259. num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
  260. # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
  261. beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
  262. dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
  263. )
  264. best_experts_heap: List[Tuple[Score, UidEndpoint]] = [] # max-heap of expert uids/endpoints ordered by scores
  265. unique_experts: Set[ExpertUID] = set()
  266. for dim_index in range(1, len(grid_scores) - 1):
  267. for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
  268. if uid_endpoint.uid not in unique_experts:
  269. push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
  270. push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
  271. unique_experts.add(uid_endpoint.uid)
  272. # form new beam using successors from the current beam
  273. dim_scores = grid_scores[dim_index]
  274. best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(
  275. beam_size,
  276. (
  277. (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
  278. for prefix_score, prefix, suffixes in beam
  279. for next_coord in suffixes.keys()
  280. if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
  281. ),
  282. )
  283. _, best_uid_prefixes = zip(*best_active_pairs)
  284. # search DHT for next step suffixes
  285. successors = await cls._get_active_successors(
  286. dht,
  287. node,
  288. best_uid_prefixes,
  289. grid_size=None,
  290. negative_caching=negative_caching,
  291. cache_expiration=cache_expiration,
  292. num_workers=num_workers,
  293. )
  294. beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
  295. if not beam:
  296. logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
  297. break
  298. # add best experts from the final beam
  299. for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
  300. if uid_endpoint.uid not in unique_experts:
  301. push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
  302. push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
  303. unique_experts.add(uid_endpoint.uid)
  304. best_experts = [
  305. LazyValue(init=partial(
  306. RemoteExpert,
  307. uid=uid_endpoint.uid,
  308. server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
  309. ))
  310. for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
  311. ]
  312. return best_experts
  313. @staticmethod
  314. def _iterate_matching_experts(
  315. beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
  316. ) -> Iterator[Tuple[Score, UidEndpoint]]:
  317. """iterate over all exemplar experts attached to current beam"""
  318. for score, prefix, suffixes in beam:
  319. for next_coord, match in suffixes.items():
  320. if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
  321. yield score, match
  322. elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
  323. expert_coords = match.uid.split(UID_DELIMITER)[1:]
  324. if all(
  325. coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
  326. for i, coord in enumerate(expert_coords)
  327. ):
  328. expert_score = sum(
  329. scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords))
  330. )
  331. yield expert_score, match
  332. else:
  333. logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
  334. else:
  335. logger.warning(f"Found incompatible expert UID: {match.uid}")
  336. def batch_find_best_experts(
  337. self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
  338. ) -> Union[List[List[RemoteExpert]], LazyFutureCaller]:
  339. """
  340. Find and return :beam_size: active experts with highest scores, use both local cache and DHT
  341. :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
  342. :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
  343. :param beam_size: how many best experts should beam search return
  344. After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
  345. Please note that any queries that fall outside the budget will still be performed in background and cached
  346. for subsequent iterations as long as DHTNode.cache_locally is True
  347. :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
  348. :returns: a list that contains *up to* k_best RemoteExpert instances
  349. """
  350. result = self.dht.run_coroutine(
  351. partial(
  352. self._batch_find_best_experts,
  353. prefix=self.uid_prefix,
  354. batch_grid_scores=batch_grid_scores,
  355. beam_size=beam_size,
  356. negative_caching=self.negative_caching,
  357. num_workers=self.num_workers,
  358. ),
  359. return_future,
  360. )
  361. p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
  362. if return_future:
  363. return LazyFutureCaller(result, lambda res: [[e.get(p2p=p2p) for e in exps] for exps in res])
  364. return [[e.get(p2p=p2p) for e in exps] for exps in result]
  365. @classmethod
  366. async def _batch_find_best_experts(
  367. cls,
  368. dht: DHT,
  369. node: DHTNode,
  370. prefix: str,
  371. batch_grid_scores: Sequence[Sequence[Tuple[float]]],
  372. beam_size: int,
  373. negative_caching: bool,
  374. num_workers: Optional[int],
  375. ) -> Sequence[Sequence[LazyValue[RemoteExpert]]]:
  376. batch_grid_scores = [
  377. [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
  378. ]
  379. coros = [
  380. cls._find_best_experts(dht, node, prefix, grid_scores, beam_size, negative_caching, num_workers)
  381. for grid_scores in batch_grid_scores
  382. ]
  383. return await asyncio.gather(*coros)