import asyncio
import heapq
from collections import deque
from functools import partial
from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

from hivemind.dht import DHT, DHTExpiration, DHTNode
from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.server.expert_uid import (
    FLAT_EXPERT,
    PREFIX_PATTERN,
    UID_DELIMITER,
    Coordinate,
    ExpertPrefix,
    ExpertUID,
    Score,
    UidEndpoint,
    is_valid_prefix,
)
from hivemind.utils import MPFuture, get_dht_time, get_logger

logger = get_logger(__name__)


class MoEBeamSearcher:
    """
    Utility class that uses the DHT to find the most suitable experts for a RemoteMixtureOfExperts.

    Each expert has an identifier of the form {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10".
    An expert identifier consists of:

    * an optional prefix that determines the expert's role, experiment name, etc.
    * one or more integers that determine that expert's position in an N-dimensional grid

    A hivemind.moe.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
    When declaring experts, the DHT stores each expert's uid and all of its prefixes until :expiration: (specified at init).
    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in the DHT:
    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``

    To enable fast beam search, the DHT maintains a dictionary of all active suffixes for every prefix
    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....}).

    RemoteMixtureOfExperts can use these prefixes to find the top-k most suitable experts with a left-to-right
    beam search. For instance, consider a RemoteMixtureOfExperts with prefix "ffn_expert" and grid size
    [100, 100, 100, 100, 100]. This MoE can query all experts with that prefix and arbitrary indices in 0...99
    along each dimension. However, not every expert in such a 100^5 grid can be alive at a given moment of time
    (the grid size is redundant). In order to find the k best "alive" experts, the MoE first ranks indices along
    the first dimension with its gating function. It can then check which of those indices correspond to "alive"
    experts by querying keys such as "ffn_expert.98". After selecting the k best indices along the first dimension,
    the MoE moves on to the second dimension. It can find the top-k index pairs (e.g. "expert.98.76") that use one
    of the k best indices from the previous step. This beam search explores one additional dimension per step and
    finds the k best experts from across the DHT in O(k * num_dimensions * dimension_size) time.

    :param dht: a running DHT daemon that is used for beam search AND local caching
    :param uid_prefix: search for experts whose uids start with this prefix
    :param grid_size: dimensions that form expert uid (see above)
    :param num_workers: number of concurrent DHT coroutines per beam search
    :param negative_caching: if True, whenever the DHT is unable to find an expert or prefix, it will cache the
      "no key" result inside the DHT for :expiration: seconds. Caching only affects beam search and has three
      main effects:

      1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix
         pointing to a now-defunct expert, these inconsistencies will be overwritten by the first peer that
         stumbles upon them. As a result, beam search will not have to wait for non-existent experts until
         their DHT entries expire;
      2. Delayed expert availability: without the negative cache, new experts are immediately available for
         beam search after they are published to the DHT. With the negative cache, there are rare cases
         (e.g. when adding new experts in place of recently defunct ones) when new experts will be initially
         invisible, but gradually become visible to more peers as those peers refresh their cache. This
         process takes at most :expiration: seconds;
      3. Faster beam search in very sparse grids: there is one edge case where the negative cache improves
         beam search performance. If an expert grid is very sparse, there can be empty indices in the first
         grid dimension (i.e. indices {i} such that *no* experts start with "{prefix}.{i}.*"). If so, the
         default beam search will be very slow due to the way it forms the initial beam. Beam search with
         the negative cache enabled runs normally. However, this is a pathological case (e.g. only 90 experts
         in an oversized 100x100 grid) that should be avoided.
    """

    def __init__(
        self,
        dht: DHT,
        uid_prefix: ExpertPrefix,
        grid_size: Sequence[int],
        num_workers: Optional[int] = None,
        negative_caching: bool = True,
        cache_expiration: DHTExpiration = 300,
        **kwargs,
    ):
        if not uid_prefix.endswith(UID_DELIMITER):
            uid_prefix += UID_DELIMITER
            logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing prefix to {uid_prefix}")
        assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
        self.dht = dht
        self.uid_prefix, self.grid_size = uid_prefix, grid_size
        self.total_grid_size = sum(grid_size)
        self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
        self.num_workers, self.dht_kwargs = num_workers, kwargs

    def get_initial_beam(
        self, scores: Sequence[float], beam_size: int, return_future: bool = False
    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
        """
        :param scores: prefer suffix coordinates that have the highest scores
        :param beam_size: select this many active suffixes with the highest scores
        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
        """
        return self.dht.run_coroutine(
            partial(
                self._get_initial_beam,
                prefix=self.uid_prefix,
                beam_size=beam_size,
                scores=tuple(scores),
                negative_caching=self.negative_caching,
                cache_expiration=self.cache_expiration,
                num_workers=self.num_workers,
            ),
            return_future,
        )
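
    # For example (a hypothetical first dimension of size 8; scores would come from a gating model):
    #
    #   beam = searcher.get_initial_beam(scores=[0.1, 2.4, -1.0, 0.3, 0.0, 1.7, 0.5, 0.9], beam_size=3)
    #   # beam ~ [(2.4, "ffn_expert.1.", {0: UidEndpoint(...), 7: UidEndpoint(...)}), ...]
    #   # i.e. the 3 best-scoring first-dimension prefixes that actually have alive experts in the DHT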

    @staticmethod
    async def _get_initial_beam(
        dht: DHT,
        node: DHTNode,
        prefix: ExpertPrefix,
        beam_size: int,
        scores: Tuple[float, ...],
        negative_caching: bool,
        cache_expiration: DHTExpiration,
        num_workers: Optional[int] = None,
    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
        num_workers = num_workers or dht.num_workers or beam_size
        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
        unattempted_indices: List[Coordinate] = sorted(
            range(len(scores)), key=scores.__getitem__
        )  # from worst to best
        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()

        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
            # dispatch additional tasks
            while unattempted_indices and len(pending_tasks) < num_workers:
                next_index = unattempted_indices.pop()  # note: this is the best unattempted index because of the sort order
                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))

            # await the next best prefix to be fetched
            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
            try:
                maybe_prefix_data = await pending_task
                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                    # keep only well-formed entries: {next_coordinate: (example_expert_uid, endpoint)}
                    successors = {
                        coord: UidEndpoint(*match.value)
                        for coord, match in maybe_prefix_data.value.items()
                        if isinstance(coord, Coordinate)
                        and isinstance(getattr(match, "value", None), list)
                        and len(match.value) == 2
                    }
                    if successors:
                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
                elif maybe_prefix_data is None and negative_caching:
                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
                    asyncio.create_task(
                        node.store(
                            pending_best_prefix,
                            subkey=-1,
                            value=None,
                            expiration_time=get_dht_time() + cache_expiration,
                        )
                    )
            except asyncio.CancelledError:
                # each queue entry is a (coordinate, prefix, task) triple; cancel the remaining tasks
                for _, _, remaining_task in pending_tasks:
                    remaining_task.cancel()
                raise
        return beam
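
    # The DHT record consumed above is assumed to look roughly like this (per the class docstring,
    # each prefix key maps active next coordinates to an (example expert uid, endpoint) pair):
    #
    #   await node.get("ffn_expert.3.")
    #   # -> ValueWithExpiration(value={7: ValueWithExpiration(value=["ffn_expert.3.7", endpoint], ...), ...}, ...)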

    def get_active_successors(
        self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
        """
        :param prefixes: a list of prefixes for which to find active successor uids
        :param grid_size: if specified, only return successors that are in range [0, grid_size)
        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
        :returns: for every prefix, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
        """
        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
        for prefix in prefixes:
            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
        return self.dht.run_coroutine(
            partial(
                self._get_active_successors,
                prefixes=list(prefixes),
                grid_size=grid_size,
                negative_caching=self.negative_caching,
                cache_expiration=self.cache_expiration,
                num_workers=self.num_workers,
            ),
            return_future=return_future,
        )
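
    # For example (hypothetical prefixes; assumes experts were declared under "ffn_expert.1."):
    #
    #   successors = searcher.get_active_successors(["ffn_expert.1.", "ffn_expert.5."])
    #   # successors ~ {"ffn_expert.1.": {0: UidEndpoint(uid="ffn_expert.1.0", endpoint=...), ...},
    #   #               "ffn_expert.5.": {}}   # empty dict: no alive experts under that prefix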

    @staticmethod
    async def _get_active_successors(
        dht: DHT,
        node: DHTNode,
        prefixes: List[ExpertPrefix],
        grid_size: Optional[int],
        negative_caching: bool,
        cache_expiration: DHTExpiration,
        num_workers: Optional[int] = None,
    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
        grid_size = grid_size or float("inf")
        num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
        for prefix, found in dht_responses.items():
            if found and isinstance(found.value, dict):
                successors[prefix] = {
                    coord: UidEndpoint(*match.value)
                    for coord, match in found.value.items()
                    if isinstance(coord, Coordinate)
                    and 0 <= coord < grid_size
                    and isinstance(getattr(match, "value", None), list)
                    and len(match.value) == 2
                }
            else:
                successors[prefix] = {}
                if found is None and negative_caching:
                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                    asyncio.create_task(
                        node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
                    )
        return successors

    def find_best_experts(
        self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
    ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
        """
        Find and return :beam_size: active experts with the highest scores, using both the local cache and the DHT.

        :param grid_scores: scores predicted for each dimension in the grid
        :type grid_scores: model scores for each grid dimension, a list of arrays of shape grid_size[i]
        :param beam_size: how many of the best experts the beam search should return
        :param return_future: if set to True, returns an MPFuture that can be awaited to get the actual result
        :returns: a list that contains *up to* beam_size RemoteExpert instances
        """
        assert len(grid_scores) == len(self.grid_size) and beam_size > 0
        return self.dht.run_coroutine(
            partial(
                self._find_best_experts,
                prefix=self.uid_prefix,
                beam_size=beam_size,
                grid_scores=list(grid_scores),
                negative_caching=self.negative_caching,
                cache_expiration=self.cache_expiration,
                num_workers=self.num_workers,
            ),
            return_future,
        )
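
    # For example, with the hypothetical 2D searcher from the class-level sketch:
    #
    #   experts = searcher.find_best_experts([scores_dim0, scores_dim1], beam_size=4)
    #   # experts ~ [RemoteExpert(uid="ffn_expert.1.7", endpoint=...), ...], best total score first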

    @classmethod
    async def _find_best_experts(
        cls,
        dht: DHT,
        node: DHTNode,
        prefix: str,
        grid_scores: Sequence[Sequence[float]],
        beam_size: int,
        negative_caching: bool,
        cache_expiration: DHTExpiration,
        num_workers: Optional[int] = None,
    ) -> List[RemoteExpert]:
        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)

        # form the initial beam from the top-k active L1 prefixes; each entry is (score, uid prefix, possible suffixes)
        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
            dht, node, prefix, beam_size, grid_scores[0], negative_caching, cache_expiration, min(beam_size, num_workers)
        )
        # min-heap that keeps the top-beam_size experts found so far (the worst of them at the root)
        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []
        unique_experts: Set[ExpertUID] = set()

        for dim_index in range(1, len(grid_scores) - 1):
            # harvest exemplar experts attached to the current beam
            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
                if uid_endpoint.uid not in unique_experts:
                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
                    unique_experts.add(uid_endpoint.uid)

            # form a new beam using successors from the current beam
            dim_scores = grid_scores[dim_index]
            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(
                beam_size,
                (
                    (prefix_score + dim_scores[next_coord], f"{beam_prefix}{next_coord}{UID_DELIMITER}")
                    for prefix_score, beam_prefix, suffixes in beam
                    for next_coord in suffixes.keys()
                    if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
                ),
            )
            _, best_uid_prefixes = zip(*best_active_pairs)

            # search the DHT for the next-step suffixes
            successors = await cls._get_active_successors(
                dht,
                node,
                best_uid_prefixes,
                grid_size=None,
                negative_caching=negative_caching,
                cache_expiration=cache_expiration,
                num_workers=num_workers,
            )
            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
            if not beam:
                logger.warning(f"Beam search had to terminate prematurely because of an empty beam (dim_index={dim_index})")
                break

        # add the best experts from the final beam
        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
            if uid_endpoint.uid not in unique_experts:
                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
                unique_experts.add(uid_endpoint.uid)

        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
        return best_experts
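
    # A small worked trace (hypothetical 3D grid, beam_size=2):
    #
    #   initial beam: [(1.7, "e.5.", {0: ..., 3: ...}), (0.9, "e.2.", {1: ...})]
    #   dim_index=1:  extend to candidate prefixes "e.5.3.", "e.5.0.", "e.2.1.", keep the 2 best
    #                 by summed score, then fetch their successors from the DHT to form the next beam
    #   final step:   score every exemplar expert in the last beam over all 3 dimensions
    #                 and return the overall top-2 as RemoteExpert instances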

    @staticmethod
    def _iterate_matching_experts(
        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
    ) -> Iterator[Tuple[Score, UidEndpoint]]:
        """iterate over all exemplar experts attached to the current beam"""
        for score, prefix, suffixes in beam:
            for next_coord, match in suffixes.items():
                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
                    yield score, match
                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
                    if all(
                        coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
                        for i, coord in enumerate(expert_coords)
                    ):
                        # an expert's total score is the sum of its per-dimension scores
                        expert_score = sum(
                            scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords))
                        )
                        yield expert_score, match
                    else:
                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
                else:
                    logger.warning(f"Found incompatible expert UID: {match.uid}")

    def batch_find_best_experts(
        self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
    ) -> Union[List[List[RemoteExpert]], MPFuture]:
        """
        For each example in a batch, find and return :beam_size: active experts with the highest scores,
        using both the local cache and the DHT.

        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid
        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
        :param beam_size: how many of the best experts the beam search should return for each example
        :param return_future: if set to True, returns an MPFuture that can be awaited to get the actual result
        :returns: for each example, a list that contains *up to* beam_size RemoteExpert instances
        """
        return self.dht.run_coroutine(
            partial(
                self._batch_find_best_experts,
                prefix=self.uid_prefix,
                batch_grid_scores=batch_grid_scores,
                beam_size=beam_size,
                negative_caching=self.negative_caching,
                cache_expiration=self.cache_expiration,
                num_workers=self.num_workers,
            ),
            return_future,
        )
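
    # For example (hypothetical batch of 2 examples over the 2D grid from the class-level sketch):
    #
    #   batch_grid_scores = [batch_scores_dim0, batch_scores_dim1]  # each of shape (batch_size=2, 32)
    #   batch_experts = searcher.batch_find_best_experts(batch_grid_scores, beam_size=4)
    #   # batch_experts ~ [[RemoteExpert, ...], [RemoteExpert, ...]]  # one list of up to 4 experts per example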

    @classmethod
    async def _batch_find_best_experts(
        cls,
        dht: DHT,
        node: DHTNode,
        prefix: str,
        batch_grid_scores: Sequence[Sequence[Sequence[float]]],
        beam_size: int,
        negative_caching: bool,
        cache_expiration: DHTExpiration,
        num_workers: Optional[int],
    ) -> Sequence[Sequence[RemoteExpert]]:
        # transpose scores from [dimension][example] into [example][dimension] so that each
        # coroutine below receives the per-dimension score vectors for a single example
        batch_grid_scores = [
            [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
        ]
        coros = [
            cls._find_best_experts(
                dht, node, prefix, grid_scores, beam_size, negative_caching, cache_expiration, num_workers
            )
            for grid_scores in batch_grid_scores
        ]
        return await asyncio.gather(*coros)