import multiprocessing as mp
import multiprocessing.pool
from functools import partial
from typing import Tuple, List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from .expert import RemoteExpert, _RemoteModuleCall
from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward


class RemoteMixtureOfExperts(nn.Module):
    """
    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
    Natively supports pytorch autograd.

    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all of the experts that
     ran the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient is averaged
     over the experts that did respond.

    :param in_features: common input size for experts and the gating function
    :param grid_size: tesseract dimensions that form expert uid (see below)
    :param uid_prefix: common prefix for all expert uids;
     an expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
    :param dht: DHTNode where the experts reside
    :param num_workers: number of threads for parallel dht operation
    :param k_best: queries this many experts with the highest scores
    :param k_min: makes sure at least this many experts return output
    :param timeout_after_k_min: waits for this many seconds after k_min experts have returned results.
     Any expert that didn't manage to return output within that delay is considered unavailable
    :param expert_padding: internal value used to denote an "absent expert". Should not coincide with any expert uid.
    :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
     allow_broadcasting=True will flatten the first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again;
     allow_broadcasting=False will raise an error
    """

    def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
        super().__init__()
        self.dht, self.grid_size = dht, grid_size
        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
        self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
        self.allow_broadcasting = allow_broadcasting

        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
        self._outputs_schema = None

    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Choose k best experts with beam search, then call the chosen experts and average their outputs.

        :param input: a tensor of values that are used to estimate the gating function, batch-first
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered results on time, nested structure, batch-first
        """
        if self.allow_broadcasting and input.ndim != 2:
            # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
            flattened_dims = input.shape[:-1]
            input_flat = input.view(-1, input.shape[-1])
            args_flat = [tensor.view(-1, *tensor.shape[len(flattened_dims):]) for tensor in args]
            kwargs_flat = {key: tensor.view(-1, *tensor.shape[len(flattened_dims):]) for key, tensor in kwargs.items()}
            out_flat = self.forward(input_flat, *args_flat, **kwargs_flat)
            return nested_map(lambda tensor: tensor.view(*flattened_dims, *tensor.shape[1:]), out_flat)

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
        chosen_experts = self.beam_search(grid_scores, self.k_best)
        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        expert_inputs = ((input, *args), kwargs)
        input_schema = nested_map(lambda x: None, expert_inputs)
        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))

        batch_jobs_args = tuple(
            (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
            for i in range(len(input))
        )

        averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
        return nested_pack(averaged_outputs_flat, self.outputs_schema)

    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
        """
        Find and return the k best experts in the grid using (exact) beam search over the product space

        :param grid_scores: scores predicted for each dimension in the grid
        :type grid_scores: a sequence of tensors of shape [batch_size, self.grid_size[i]]
        :param k_best: how many of the top experts participate in the computation
        :param kwargs: extra keyword parameters passed to self.dht.first_k_active
        :returns: a list of *batch_size* lists, one per sample; each inner list contains
         RemoteExpert instances for *up to* k_best experts
        """
        assert len(grid_scores) == len(self.grid_size)
        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
        batch_size = len(grid_scores[0])
        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
        scores = np.zeros([batch_size, 1], dtype=np.float64)

        delimeters = np.array(self.dht.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat

        for dim_index, dim_scores in enumerate(grid_scores):
            dim_scores = check_numpy(dim_scores)
            assert dim_scores.shape[-1] == self.grid_size[dim_index]

            # create all possible successors from the current beam
            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
            new_candidates = new_candidates.reshape([batch_size, -1])

            new_scores = scores[:, :, None] + dim_scores[:, None, :]
            new_scores = new_scores.reshape([batch_size, -1])

            # select k best candidates according to scores but only those that are still active
            new_order = np.argsort(- new_scores, axis=-1)
            top_alive_lookups = [
                run_in_background(self.dht.first_k_active, cands[order], k_best, **kwargs)
                for cands, order in zip(new_candidates, new_order)]
            batch_cand_to_score = [
                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
            top_alive_prefixes = [result.result() for result in top_alive_lookups]
            top_alive_scores = [list(map(cand_to_score.get, top_cands))
                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]

            # pad up to beam size
            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
                             for row in top_alive_prefixes], dtype='object')
            scores = np.array([row + [-float('inf')] * (k_best - len(row))
                               for row in top_alive_scores], dtype='float32')

        unique_experts = self.dht.get_experts(list(set(
            uid for row in beam for uid in row if uid != self.expert_padding)))
        if self._outputs_schema is None:
            self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}

        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]

    def compute_expert_scores(
            self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
        """
        Compute the gating score of each chosen expert by summing its per-dimension grid scores.

        :param grid_scores: scores predicted for each grid dimension, tensors of shape [batch_size, grid_size[i]]
        :param batch_experts: for each sample in the batch, a list of chosen RemoteExpert instances
        :returns: a tensor of shape [batch_size, max_num_experts]; positions without a chosen expert are set to -inf
        """
        expert_counts = list(map(len, batch_experts))
        batch_size = len(batch_experts)
        max_num_experts = max(expert_counts)
        total_num_experts = sum(expert_counts)
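
        # Added note: the next four lines map every chosen expert, flattened across the batch, to
        # (the sample it belongs to, its position within that sample's list of chosen experts).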
        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
        flat_experts = [expert for row in batch_experts for expert in row]

        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMETER):]
            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMETER)))
            grid_indices[i] = expert_indices

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
        return scores

    @property
    def outputs_schema(self):
        if self._outputs_schema is None:
            # grab some expert to set ensemble output shape
            dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(self.grid_size, dim=-1)
            self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
        return self._outputs_schema


class _RemoteMoECall(torch.autograd.Function):
    """
    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
    This function can recover from individual failures during the forward and/or backward pass.
    For a user-friendly version of this function, use the RemoteMixtureOfExperts module.
    """

    @classmethod
    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
                backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)

        # 1. call experts and await results
        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)

        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
                                                        if not isinstance(result, BaseException)])
        #         ^              ^              ^-- a list of indices of experts that returned outputs in time
        #          \              \-- list of outputs of every expert that didn't die on us
        #           \-- a list of autograd contexts, used for parallel backward

        # 2. compute softmax weights for alive experts and average outputs
        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)

        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
                                     for stacked_out in stacked_alive_outputs)

        # 3. save individual outputs for backward pass
        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
        return tuple(map(torch.Tensor.detach, flat_average_outputs))

    @classmethod
    @once_differentiable
    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        """ Like normal backward, but we ignore any experts that failed during the backward pass """
        expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors

        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
        backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)
                                                                     if not isinstance(grads, BaseException)))
        backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
        backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
        survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
        weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
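
        # Added note: each expert's returned gradients were already scaled by its forward-pass weight
        # (see _run_expert_backward), so weight_ratios rescales them to the softmax weights recomputed
        # over the backward survivors only.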
        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
                                 for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))

        # compute grad w.r.t. logits
        grad_wrt_probs = sum(tuple(
            torch.sum(grad_out[None, ...] * stacked_alive_out[backward_survivors_in_alive_ix],
                      dim=tuple(range(1, stacked_alive_out.ndim)))
            for grad_out, stacked_alive_out in zip(grad_outputs_flat, stacked_alive_outputs)
        ))
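
        # Added note: for p = softmax(logits), the Jacobian is dp/dlogits = diag(p) - p p^T,
        # hence grad_logits = grad_probs @ (diag(p) - p p^T) below.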
        softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
        grad_wrt_logits = grad_wrt_probs @ softmax_jacobian

        return grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs

    @staticmethod
    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
        """ Call a remote expert and return flattened outputs. Compatible with concurrent autograd. """
        flat_inputs = nested_flatten((args, kwargs))
        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, *flat_inputs)

    @staticmethod
    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
        return grad_inputs