# switch_moe.py

from __future__ import annotations

from typing import List, Tuple

import grpc
import torch

from hivemind.moe.client.expert import DUMMY, RemoteExpert
from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
from hivemind.moe.server.expert_uid import UID_DELIMITER
from hivemind.utils import nested_flatten, nested_pack
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
    """
    A module implementing Switch Transformers [1] Mixture-of-Experts inference with remote experts.

    [1] Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
     William Fedus, Barret Zoph, Noam Shazeer. https://arxiv.org/abs/2101.03961

    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all experts that ran
     the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient is averaged without
     the missing experts.

    :param in_features: common input size for experts and the gating function
    :param grid_size: dimensions that form expert uid (see below)
    :param uid_prefix: common prefix for all expert uids (must end with '.')
    :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
    :param dht: a DHT instance used to search for best experts
    :param k_best: average this many highest-scoring experts to compute activations
    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
     Any expert that didn't manage to return output after that delay is considered unavailable
    :param utilization_alpha: smoothing coefficient for the exponential moving average of expert utilization
    :param grid_dropout: probability of keeping each grid index during routing (dropped indices get a score of -inf)
    :param jitter_eps: amplitude of the multiplicative jitter applied to the gating input
    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
    :param allow_zero_outputs: whether to return just the input if no experts respond on forward pass
    """
    def __init__(
        self,
        *,
        grid_size: Tuple[int, ...],
        utilization_alpha: float = 0.9,
        grid_dropout: float = 1.0,
        jitter_eps: float = 1e-2,
        k_best=1,
        k_min=0,
        backward_k_min=0,
        allow_zero_outputs=True,
        **kwargs,
    ):
        super().__init__(
            grid_size=grid_size,
            k_best=k_best,
            k_min=k_min,
            backward_k_min=backward_k_min,
            allow_zero_outputs=allow_zero_outputs,
            **kwargs,
        )
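
        # Running estimate of how often each grid index is routed to, stored as one concatenated
        # vector over all grid dimensions. It is updated as an exponential moving average in forward()
        # and consumed by the auxiliary load-balancing loss. Initialized to a uniform distribution.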
        initial_utilization = torch.cat(
            [torch.tensor([1 / dim_size for _ in range(dim_size)], dtype=torch.float) for dim_size in grid_size],
        )
        self.register_buffer("grid_utilization", initial_utilization)
        self.utilization_alpha = utilization_alpha
        self.grid_dropout = grid_dropout
        self.jitter_eps = jitter_eps

    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
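        # If the input has extra dimensions (e.g. [batch, seq_len, in_features]), average over the
        # intermediate dimensions so that routing uses a single gating vector per sample.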
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # Multiplicative jitter for regularized routing
        jitter_noise = torch.empty_like(input_for_gating).uniform_(1 - self.jitter_eps, 1 + self.jitter_eps)
        input_for_gating *= jitter_noise

        # Compute scores, find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
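
        # Grid dropout: keep each grid index with probability self.grid_dropout; indices that are
        # dropped receive a score of -inf and are therefore never selected by the beam search.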
        grid_dropout_masks = (
            (
                torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
                < self.grid_dropout
            )
            for dim_size in self.beam_search.grid_size
        )
        grid_scores_dropout = [
            torch.where(
                dropout_mask,
                grid_score,
                torch.full((1,), float("-inf"), device=grid_score.device, dtype=grid_score.dtype),
            )
            for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
        ]
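
        # Convert the masked scores into per-dimension routing probabilities, then look up the
        # k_best highest-scoring experts for every sample via DHT beam search.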
        grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
            [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best
        )

        if self._expert_info is None:
            try:
                self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
            except StopIteration:
                raise RuntimeError(
                    "No responding experts found during beam search. Check that UID prefixes and "
                    "the grid size are consistent with running Server instances."
                )
            except grpc.RpcError as e:
                logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
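
        # Send the inputs to all chosen experts in parallel; expert_mask[i, j] is True iff the j-th
        # expert chosen for sample i actually returned an output within the configured timeouts.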
        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY,
            chosen_experts,
            self.k_min,
            self.backward_k_min,
            self.timeout_after_k_min,
            self.forward_timeout,
            self.backward_timeout,
            self.detect_anomalies,
            self.allow_zero_outputs,
            self.info,
            *nested_flatten(((input, *args), kwargs)),
        )
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
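
        # Update the exponential moving average of per-index utilization with this batch's routing statistics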
        batch_utilization = self._compute_batch_utilization(chosen_experts, expert_mask)
        self.grid_utilization = (
            self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization
        )

        # compute expert probabilities as product across grid dimensions
        expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
        masked_probs = torch.zeros((1,), device=expert_probs.device, dtype=expert_probs.dtype)
        expert_probs = torch.where(expert_mask, expert_probs, masked_probs)

        # multiply outputs by expert probabilities
        averaged_outputs_flat = [
            (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
            for tensor in expert_outputs
        ]  # ^-- multiply by softmax weights along first 2 axes

        packed_outputs = nested_pack(averaged_outputs_flat, self.info["outputs_schema"])

        # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
        # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
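        # With each dimension's utilization vector, this equals N * sum_i(f_i * P_i) per grid dimension,
        # as in the Switch Transformers auxiliary loss, except that f_i here is the EMA-smoothed
        # utilization rather than the exact fraction of examples routed to index i in this batch.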
        balancing_loss = torch.stack(
            [
                torch.mean(dim_softmax.mean(0) * dim_utilization) * dim_size**2
                for dim_softmax, dim_utilization, dim_size in zip(
                    grid_softmax,
                    self.grid_utilization.split_with_sizes(self.beam_search.grid_size),
                    self.beam_search.grid_size,
                )
            ]
        ).sum()

        # residual connection
        if isinstance(packed_outputs, torch.Tensor):
            packed_outputs = packed_outputs + input
        else:
            packed_outputs[0] = packed_outputs[0] + input

        return packed_outputs, balancing_loss

    @torch.no_grad()
    def _compute_batch_utilization(self, batch_experts, expert_mask):
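        # For each grid dimension, estimate what fraction of the experts that actually responded
        # in this batch used each index of that dimension.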
        batch_utilization = [
            torch.zeros((dim_size,), dtype=self.grid_utilization.dtype, device=self.grid_utilization.device)
            for dim_size in self.beam_search.grid_size
        ]

        # out of chosen_experts, select those for which expert_mask is True
        for (sample_idx, expert_idx) in expert_mask.nonzero().cpu().numpy():
            expert = batch_experts[sample_idx][expert_idx]

            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))

            for dim_index, dim_utilization in zip(expert_indices, batch_utilization):
                dim_utilization[dim_index] += 1
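
        # Normalize each dimension's counts to sum to 1 (L1 normalization) and concatenate them
        # into a single flat vector with the same layout as self.grid_utilization.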
        return torch.cat(
            [torch.nn.functional.normalize(dim_utilization, p=1, dim=0) for dim_utilization in batch_utilization]
        )

    def compute_expert_scores(
        self, grid_probs: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]
    ) -> torch.Tensor:
        """
        Compute scores for each expert by multiplying grid probabilities, autograd-friendly

        :param grid_probs: list of torch tensors, i-th tensor contains scores for i-th grid dimension
        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
        :returns: a tensor of scores, float32[batch_size, k]
        :note: if some rows in the batch have fewer than the maximum number of experts, their scores are padded with -inf
        """
        expert_counts = list(map(len, batch_experts))
        batch_size = len(batch_experts)
        max_num_experts = max(expert_counts)
        total_num_experts = sum(expert_counts)

        device = grid_probs[0].device
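
        # Map every selected expert to its (row in batch, position within row) so the flat score
        # vector can later be scattered back into a [batch_size, max_num_experts] matrix.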
        expert_index_in_batch = torch.arange(total_num_experts, device=device)
        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]

        flat_experts = [expert for row in batch_experts for expert in row]
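        # Parse each expert's UID suffix into integer grid coordinates (one index per grid dimension)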
        grid_indices = torch.zeros([len(flat_experts), len(grid_probs)], dtype=torch.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
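
        # Each expert's score is the product of its routing probabilities along every grid dimension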
        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
            for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)
        ]
        flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)

        scores = torch.full((batch_size, max_num_experts), fill_value=-float("inf"), device=device)
        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
        return scores
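

# A minimal usage sketch, assuming a hivemind DHT whose peers already serve Switch experts under the
# given uid_prefix; the initial_peers address, feature size and grid size below are placeholder values:
#
#   import hivemind
#
#   dht = hivemind.DHT(initial_peers=["/ip4/..."], start=True)
#   moe = RemoteSwitchMixtureOfExperts(
#       in_features=512, grid_size=(32,), dht=dht, uid_prefix="expert.", k_best=1
#   )
#   outputs, balancing_loss = moe(torch.randn(4, 512))
#   # `outputs` follows the experts' output schema (with the residual connection already added);
#   # add `balancing_loss`, scaled by a small coefficient, to the training loss to encourage
#   # balanced routing across experts.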