# moe.py

from __future__ import annotations

import time
from queue import Queue, Empty
from typing import Tuple, List, Optional, Dict, Any

import grpc
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.client.beam_search import MoEBeamSearcher
from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.server.expert_uid import UID_DELIMITER
from hivemind.utils import nested_pack, nested_flatten
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.logging import get_logger
logger = get_logger(__name__)  # module-level logger; used to report failed or anomalous expert calls
  18. class RemoteMixtureOfExperts(nn.Module):
  19. """
  20. A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
  21. Natively supports pytorch autograd.
  22. :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
  23. forward pass are guaranteed to perform backward pass. In the latter case, gradient will be averaged without
  24. the missing experts
  25. :param in_features: common input size for experts and gating function
  26. :param grid_size: dimensions that form expert uid (see below)
  27. :param uid_prefix: common prefix for all expert uids (must end with '.')
  28. :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
  29. :param dht: a DHT instance used to search for best experts
  30. :param k_best: average this many highest-scoring experts to compute activations
  31. :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
  32. :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
  33. :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
  34. Any expert that didn't manage to return output after that delay is considered unavailable
  35. """
  36. def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
  37. k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
  38. backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
  39. **dht_kwargs):
  40. super().__init__()
  41. self.dht = dht
  42. self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
  43. self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
  44. self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
  45. self.timeout_after_k_min = timeout_after_k_min
  46. self.detect_anomalies = detect_anomalies
  47. self.proj = nn.Linear(in_features, sum(grid_size)) # jointly predict logits for all grid dimensions
  48. self._expert_info = None # expert['info'] from one of experts in the grid
  49. def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
  50. """
  51. Choose k best experts with beam search, then call chosen experts and average their outputs.
  52. Input tensor is averaged over all dimensions except for first and last
  53. (we assume that extra dimensions represent sequence length or image height/width)
  54. :param input: a tensor of values that are used to estimate gating function, batch-first.
  55. :param args: extra positional parameters that will be passed to each expert after input, batch-first
  56. :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
  57. :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
  58. """
  59. if input.ndim != 2:
  60. input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
  61. else:
  62. input_for_gating = input
  63. # 1. compute scores and find most appropriate experts with beam search
  64. grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
  65. chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
  66. [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best)
  67. if self._expert_info is None:
  68. try:
  69. self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
  70. except grpc.RpcError as e:
  71. logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
  72. expert_mask, *expert_outputs = _RemoteCallMany.apply(
  73. DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
  74. self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
  75. # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
  76. expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
  77. masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
  78. expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
  79. expert_weights = torch.softmax(expert_logits, dim=1)
  80. averaged_outputs_flat = [
  81. (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
  82. for tensor in expert_outputs] # ^-- multiply by softmax weights along first 2 axes
  83. return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
  84. def compute_expert_scores(
  85. self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
  86. """
  87. Compute scores for each expert by adding up grid scores, autograd-friendly
  88. :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
  89. :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
  90. :returns: a tensor of scores, float32[batch_size, k]
  91. :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
  92. """
  93. expert_counts = list(map(len, batch_experts))
  94. batch_size = len(batch_experts)
  95. max_num_experts = max(expert_counts)
  96. total_num_experts = sum(expert_counts)
  97. expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
  98. expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
  99. flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
  100. flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
  101. flat_experts = [expert for row in batch_experts for expert in row]
  102. grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
  103. for i, expert in enumerate(flat_experts):
  104. expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
  105. expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
  106. grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
  107. scores_per_dim = [
  108. dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
  109. for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
  110. flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
  111. scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
  112. scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
  113. return scores
  114. @property
  115. def info(self):
  116. if self._expert_info is None:
  117. # grab some expert to set ensemble output shape
  118. proj_device = self.proj.weight.device
  119. dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
  120. dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
  121. dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
  122. self._expert_info = dummy_experts[0].info
  123. return self._expert_info
  124. class _RemoteCallMany(torch.autograd.Function):
  125. """
  126. Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses
  127. This function that can recover from individual failures during forward and/or backward pass as long as at least
  128. one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
  129. Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
  130. experts that failed during backward will be treated as constants (i.e. gradients of through them are zeros)
  131. """
  132. @classmethod
  133. def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
  134. timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
  135. detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
  136. assert not torch.is_grad_enabled()
  137. num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
  138. flat_inputs_cpu = []
  139. for tensor in flat_inputs:
  140. if detect_anomalies and not tensor.isfinite().all():
  141. raise ValueError("One of inputs has nan/inf values")
  142. flat_inputs_cpu.append(tensor.cpu())
  143. flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
  144. assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
  145. # dispatch tasks to all remote experts collect responses
  146. pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
  147. for i in range(num_samples):
  148. for j, expert in enumerate(experts_per_sample[i]):
  149. input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
  150. flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
  151. stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
  152. new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
  153. pending_tasks[new_task] = (i, j)
  154. alive_grid_indices, alive_flat_outputs = cls._collect_responses(
  155. pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
  156. if len(alive_grid_indices) == 0:
  157. raise TimeoutError("Forward pass: no alive experts responded within timeout.")
  158. # assemble responses
  159. alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
  160. mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
  161. mask[alive_ii, alive_jj] = True
  162. alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
  163. # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
  164. outputs = []
  165. for response_stacked in alive_flat_outputs_stacked:
  166. output = torch.zeros(
  167. [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
  168. dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
  169. output[alive_ii, alive_jj] = response_stacked
  170. outputs.append(output.to(flat_inputs[0].device))
  171. # save individual outputs for backward pass
  172. ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
  173. ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
  174. detect_anomalies)
  175. return (mask,) + tuple(outputs)
  176. @classmethod
  177. @once_differentiable
  178. def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
  179. assert not torch.is_grad_enabled()
  180. (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
  181. detect_anomalies) = ctx._saved_non_tensors
  182. alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
  183. dummy_grad_mask, *flat_grad_outputs = raw_grads
  184. flat_grad_outputs_cpu = []
  185. for tensor in flat_grad_outputs:
  186. if detect_anomalies and not tensor.isfinite().all():
  187. raise ValueError("One of gradients has nan/inf values")
  188. flat_grad_outputs_cpu.append(tensor.cpu())
  189. num_samples, max_experts = dummy_grad_mask.shape
  190. inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
  191. grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
  192. backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
  193. # dispatch tasks to all remote experts, collect responses
  194. pending_tasks = {}
  195. for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
  196. inputs_per_expert, grad_outputs_per_expert):
  197. expert = expert_per_sample[i.item()][j.item()]
  198. stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
  199. inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
  200. tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
  201. for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
  202. new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
  203. pending_tasks[new_task] = (i, j)
  204. backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
  205. pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
  206. if len(backward_survivor_indices) == 0:
  207. raise TimeoutError("Backward pass: no alive experts responded within timeout.")
  208. # assemble responses
  209. backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
  210. survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
  211. # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
  212. grad_inputs = []
  213. for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
  214. grad_input_per_expert = torch.zeros( # gradient tensor with individual contributions from each expert
  215. (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
  216. device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
  217. grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
  218. # sum gradients from each expert
  219. grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
  220. return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
  221. @staticmethod
  222. def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
  223. timeout_total: Optional[float], timeout_after_k_min: Optional[float], detect_anomalies: bool
  224. ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
  225. """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
  226. timeout_total = float('inf') if timeout_total is None else timeout_total
  227. timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
  228. num_successful_tasks = [0 for _ in range(num_samples)]
  229. pending_samples = num_samples # samples for which we have less than k_min results
  230. finished_indices, finished_outputs = [], []
  231. t_finish = time.perf_counter() + timeout_total
  232. pending_tasks = set(task_to_indices.keys())
  233. finished_tasks = Queue()
  234. try:
  235. # the algorithm below is essentially futures.as_completed, but for grpc.Future
  236. for task in pending_tasks:
  237. task.add_done_callback(finished_tasks.put)
  238. for _ in range(len(task_to_indices)):
  239. timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float('inf') else None
  240. task = finished_tasks.get(timeout=timeout)
  241. pending_tasks.discard(task)
  242. task_output = _process_dispatched_task(task, detect_anomalies)
  243. if task_output is not None:
  244. finished_indices.append(task_to_indices[task])
  245. finished_outputs.append(task_output)
  246. # count how many successes we have for each input sample
  247. sample_index = task_to_indices[task][0]
  248. num_successful_tasks[sample_index] += 1
  249. if num_successful_tasks[sample_index] == k_min:
  250. pending_samples -= 1
  251. if pending_samples <= 0: # all tasks finished, await stragglers for at most timeout_after_k_min
  252. t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
  253. except Empty:
  254. pass # we reached t_finish, this is normal behavior
  255. finally:
  256. for task in pending_tasks:
  257. task.cancel()
  258. return finished_indices, finished_outputs
  259. def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
  260. if task.exception() or task.cancelled():
  261. logger.warning(f"Task {task} failed: {type(task.exception())}")
  262. return None
  263. deserialized_outputs = []
  264. for tensor in task.result().tensors:
  265. deserialized_tensor = deserialize_torch_tensor(tensor)
  266. if detect_anomalies and not deserialized_tensor.isfinite().all():
  267. logger.error(f"Task {task} failed: output tensor contains nan/inf values")
  268. return None
  269. deserialized_outputs.append(deserialized_tensor)
  270. return tuple(deserialized_outputs)