from __future__ import annotations
import asyncio
import time
from typing import Tuple, List, Optional, Awaitable, Set, Dict, Any

import grpc.experimental.aio
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)

class RemoteMixtureOfExperts(nn.Module):
    """
    A torch module that performs mixture-of-experts inference with a local gating function and multiple remote experts.
    Natively supports pytorch autograd.

    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all of those that ran
     the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient is averaged without
     the missing experts.

    :param in_features: common input size for experts and the gating function
    :param grid_size: hivemind dimensions that form expert uid (see below)
    :param uid_prefix: common prefix for all expert uids;
     an expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
    :param dht: DHT where the experts reside
    :param k_best: queries this many experts with highest scores
    :param k_min: makes sure at least this many experts return output
    :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
     Any expert that didn't manage to return output after that delay is considered unavailable.
    :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
     allow_broadcasting=True will flatten the first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again;
     allow_broadcasting=False will raise an error
    """

    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, k_best: int, k_min: int = 1,
                 forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, uid_prefix='',
                 allow_broadcasting=True, loop: Optional[asyncio.BaseEventLoop] = None):
        super().__init__()
        self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
        self.loop = loop or asyncio.new_event_loop()
        # fmt:off
        assert not self.loop.is_running(), "Event loop is already running. If in jupyter, please apply nest_asyncio " \
            "(pip install nest_asyncio, https://pypi.org/project/nest-asyncio) and pass loop=asyncio.new_event_loop()"
        # fmt:on
        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.timeout_after_k_min = timeout_after_k_min
        self.allow_broadcasting = allow_broadcasting

        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
        self._expert_info = None  # expert['info'] from one of experts in the grid

    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Choose the k best experts with beam search, then call the chosen experts and average their outputs.
        The input tensor is averaged over all dimensions except the first and the last
        (we assume that extra dimensions represent sequence length or image dimensions).

        :param input: a tensor of values that are used to estimate the gating function, batch-first
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered results on time, nested structure of batch-first
        """
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)

        async def _search():
            coroutines = [asyncio.create_task(self.beam_search(
                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
                for i in range(len(input))]
            return list(await asyncio.gather(*coroutines))

        chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())
        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
            self.backward_timeout, self.loop, self.info, *nested_flatten(((input, *args), kwargs)))
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
        expert_weights = torch.softmax(expert_logits, dim=1)
        averaged_outputs_flat = [
            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])

    async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
        """
        Find and return the k best experts in the grid using (exact) beam search over the product space.

        :param grid_scores: scores predicted for each grid dimension for a single sample
        :type grid_scores: a sequence of tensors of shape [self.grid_size[i]]
        :param k_best: how many of the top experts participate in the computation
        :param kwargs: extra keyword parameters passed to self.dht.first_k_active
        :returns: a list of up to k_best RemoteExpert instances chosen for this sample
        """
        assert len(grid_scores) == len(self.grid_size)
        assert all(dim_scores.shape == (self.grid_size[dim_index],) for dim_index, dim_scores in enumerate(grid_scores))
        grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores]

        beam_experts: List[RemoteExpert] = []
        beam: List[str] = [self.uid_prefix]
        beam_scores = torch.zeros(1)

        for dim_index, dim_scores in enumerate(grid_scores):
            # create all possible successors from current beam and sort them by total score
            expanded_scores = beam_scores[:, None] + dim_scores[None, :]
            sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores))
                              for flat_i in (-expanded_scores).flatten().argsort().numpy()]
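            # ^-- (-expanded_scores).argsort() ranks flattened candidates best-first; floor division
            # and modulo by the row length recover (beam row, grid column) pairs for each candidate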
            sorted_candidates = [f"{beam[row]}{self.dht.UID_DELIMITER}{col}" for row, col in sorted_indices]
            candidate_to_indices = dict(zip(sorted_candidates, sorted_indices))

            # select k best candidates according to scores but only those that are still active
            best_alive_prefixes: Dict[str, RemoteExpert] = await self.dht.first_k_active(
                uid_prefixes=sorted_candidates, k=k_best, return_future=True, **kwargs)
            if not best_alive_prefixes:
                logger.warning(f"Grid is empty: found none of {sorted_candidates}")
                break
            beam = list(best_alive_prefixes.keys())
            beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
            beam_experts = list(best_alive_prefixes.values())

        if self._expert_info is None:
            try:
                self._expert_info = beam_experts[0].info
            except grpc.RpcError as e:
                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

        return beam_experts

    def compute_expert_scores(
            self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
        """
        Compute scores for each expert by adding up grid scores, autograd-friendly.

        :param grid_scores: list of torch tensors, i-th tensor contains scores for the i-th grid dimension
        :param batch_experts: a list (one per sample in the batch) of lists of up to k experts chosen for that sample
        :returns: a tensor of scores, float32[batch_size, k]
        :note: if some rows in the batch have fewer than the maximum number of experts, their scores are padded with -inf
        """
        expert_counts = list(map(len, batch_experts))
        batch_size = len(batch_experts)
        max_num_experts = max(expert_counts)
        total_num_experts = sum(expert_counts)

        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
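        # ^-- e.g. expert_counts=[2, 3, 1] gives expert_strides=[0, 2, 5]: flat experts 0-1 map to
        # (sample 0, slots 0-1), experts 2-4 to (sample 1, slots 0-2), and expert 5 to (sample 2, slot 0)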
        flat_experts = [expert for row in batch_experts for expert in row]

        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
        return scores

    @property
    def info(self):
        if self._expert_info is None:
            # grab some expert to set ensemble output shape
            proj_device = self.proj.weight.device
            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
            # beam_search expects per-sample 1d score tensors, hence the [0] below
            dummy_experts = self.loop.run_until_complete(
                self.beam_search([dim_scores[0] for dim_scores in dummy_scores], k_best=1))
            self._expert_info = dummy_experts[0].info
        return self._expert_info


class _RemoteCallMany(torch.autograd.Function):
    """
    Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses.
    This function can recover from individual failures during the forward and/or backward pass as long as at least
    one expert succeeds for each input. For a user-friendly version of this function, use the RemoteMixtureOfExperts module.

    Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
    experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
    """

    @classmethod
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
                loop: asyncio.base_events.BaseEventLoop, info: Dict[str, Any],
                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
        flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        async def _forward():
            # dispatch tasks to all remote experts, await responses
            pending_tasks = {
                asyncio.create_task(cls._forward_one_expert((i, j), expert, info, flat_inputs_per_sample[i]))
                for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
            }
            alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
                pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)

            # assemble responses
            alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
            mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
            mask[alive_ii, alive_jj] = True

            alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
            # ^-- list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

            outputs = []
            for response_stacked in alive_flat_outputs_stacked:
                output = torch.zeros(
                    [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
                    dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
                output[alive_ii, alive_jj] = response_stacked
                outputs.append(output)

            # save individual outputs for backward pass
            ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
            ctx._saved_non_tensors = (loop, info, backward_k_min, backward_timeout,
                                      timeout_after_k_min, experts_per_sample)
            return (mask,) + tuple(outputs)

        return loop.run_until_complete(_forward())

    @classmethod
    @once_differentiable
    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        assert not torch.is_grad_enabled()
        loop, info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
        dummy_grad_mask, *flat_grad_outputs = raw_grads
        num_samples, max_experts = dummy_grad_mask.shape

        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
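        # ^-- only experts that survived the forward pass (indexed by alive_ii, alive_jj) are asked for gradients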

        async def _backward():
            # dispatch tasks to all remote experts, await responses
            pending_tasks = set()
            for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                        inputs_per_expert, grad_outputs_per_expert):
                pending_tasks.add(asyncio.create_task(cls._backward_one_expert(
                    (i, j), experts_per_sample[i.item()][j.item()], info, inputs_ij, grad_outputs_ij)))

            backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
                pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)

            # assemble responses
            backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices))
            survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
            # ^-- list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]

            grad_inputs = []
            for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
                grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                    (num_samples, max_experts, *flat_inputs[i].shape[1:]),
                    device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
                grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
                grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert

            return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)

        return loop.run_until_complete(_backward())

    @staticmethod
    async def _forward_one_expert(
            grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any], inputs: Tuple[torch.Tensor]):
        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
        try:
            outputs = await stub.forward(runtime_pb2.ExpertRequest(
                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression) for tensor, proto in
                                         zip(inputs, nested_flatten(info['forward_schema']))]))
            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")

    @staticmethod
    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any],
                                   inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
        inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs)))
        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
        try:
            grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression)
                                         for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]))
            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")

    @staticmethod
    async def _wait_for_responses(
            pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
            num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
        """ await at least k_min results per sample plus any stragglers that finish within
        timeout_after_k_min, then cancel the remaining tasks """
        timeout_total = float('inf') if timeout_total is None else timeout_total
        timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
        num_successful_tasks = [0 for _ in range(num_samples)]
        pending_samples = num_samples  # samples for which we have less than k_min results
        finished_indices, finished_outputs = [], []
        t_finish = time.perf_counter() + timeout_total

        while pending_tasks and time.perf_counter() <= t_finish:
            finished_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED,
                                                               timeout=t_finish - time.perf_counter())
            for task in finished_tasks:
                if not task.result():
                    continue
                task_indices, task_flat_outputs = await task
                finished_indices.append(task_indices)
                finished_outputs.append(task_flat_outputs)

                sample_index = task_indices[0]
                num_successful_tasks[sample_index] += 1
                if num_successful_tasks[sample_index] == k_min:
                    pending_samples -= 1
                    if pending_samples <= 0:  # every sample has at least k_min results; await stragglers for at most timeout_after_k_min
                        t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)

        for task in pending_tasks:
            task.cancel()
        return finished_indices, finished_outputs