moe.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. from __future__ import annotations
  2. import time
  3. from queue import Queue, Empty
  4. from typing import Tuple, List, Optional, Dict, Any
  5. import grpc
  6. import torch
  7. import torch.nn as nn
  8. from torch.autograd.function import once_differentiable
  9. import hivemind
  10. from hivemind.client.beam_search import MoEBeamSearcher
  11. from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
  12. from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
  13. from hivemind.server.expert_uid import UID_DELIMITER
  14. from hivemind.utils import nested_pack, nested_flatten, nested_map
  15. from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
  16. from hivemind.utils.logging import get_logger
# module-level logger: used below to warn about failed expert calls and to report non-finite expert outputs
logger = get_logger(__name__)
  18. class RemoteMixtureOfExperts(nn.Module):
  19. """
  20. A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts.
  21. Natively supports pytorch autograd.
  22. :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
  23. forward pass are guaranteed to perform backward pass. In the latter case, gradient will be averaged without
  24. the missing experts
  25. :param in_features: common input size for experts and gating function
  26. :param grid_size: dimensions that form expert uid (see below)
  27. :param uid_prefix: common prefix for all expert uids (must end with '.')
  28. :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
  29. :param dht: a DHT instance used to search for best experts
  30. :param k_best: average this many highest-scoring experts to compute activations
  31. :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
  32. :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
  33. Any expert that didn't manage to return output after that delay is considered unavailable
  34. :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
  35. :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
  36. """
  37. def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
  38. k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
  39. backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
  40. allow_zero_outputs: bool = False, **dht_kwargs):
  41. super().__init__()
  42. self.dht = dht
  43. self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
  44. self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
  45. self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
  46. self.timeout_after_k_min = timeout_after_k_min
  47. self.detect_anomalies = detect_anomalies
  48. self.allow_zero_outputs = allow_zero_outputs
  49. # jointly predict logits for all grid dimensions
  50. self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
  51. self._expert_info = None # expert['info'] from one of experts in the grid
  52. def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
  53. """
  54. Choose k best experts with beam search, then call chosen experts and average their outputs.
  55. Input tensor is averaged over all dimensions except for first and last
  56. (we assume that extra dimensions represent sequence length or image height/width)
  57. :param input: a tensor of values that are used to estimate gating function, batch-first.
  58. :param args: extra positional parameters that will be passed to each expert after input, batch-first
  59. :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
  60. :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
  61. """
  62. if input.ndim != 2:
  63. input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
  64. else:
  65. input_for_gating = input
  66. # 1. compute scores and find most appropriate experts with beam search
  67. grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
  68. chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
  69. [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best)
  70. if self._expert_info is None:
  71. try:
  72. self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
  73. except StopIteration:
  74. raise RuntimeError("No responding experts found during beam search. Check that UID prefixes and "
  75. "the grid size are consistent with running Server instances.")
  76. except grpc.RpcError as e:
  77. logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
  78. expert_mask, *expert_outputs = _RemoteCallMany.apply(
  79. DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
  80. self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
  81. *nested_flatten(((input, *args), kwargs)))
  82. # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
  83. expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
  84. masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
  85. expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
  86. expert_weights = torch.softmax(expert_logits, dim=1)
  87. averaged_outputs_flat = [
  88. (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
  89. for tensor in expert_outputs] # ^-- multiply by softmax weights along first 2 axes
  90. return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
  91. def compute_expert_scores(
  92. self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
  93. """
  94. Compute scores for each expert by adding up grid scores, autograd-friendly
  95. :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
  96. :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
  97. :returns: a tensor of scores, float32[batch_size, k]
  98. :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
  99. """
  100. expert_counts = list(map(len, batch_experts))
  101. batch_size = len(batch_experts)
  102. max_num_experts = max(expert_counts)
  103. total_num_experts = sum(expert_counts)
  104. device = grid_scores[0].device
  105. expert_index_in_batch = torch.arange(total_num_experts, device=device)
  106. expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
  107. flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
  108. flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
  109. flat_experts = [expert for row in batch_experts for expert in row]
  110. grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
  111. for i, expert in enumerate(flat_experts):
  112. expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
  113. expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
  114. grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
  115. scores_per_dim = [
  116. dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
  117. for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
  118. flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
  119. scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
  120. scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
  121. return scores
  122. @property
  123. def info(self):
  124. if self._expert_info is None:
  125. # grab some expert to set ensemble output shape
  126. proj_device = self.proj.weight.device
  127. dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
  128. dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
  129. dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
  130. self._expert_info = dummy_experts[0].info
  131. return self._expert_info
  132. class _RemoteCallMany(torch.autograd.Function):
  133. """
  134. Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses
  135. This function that can recover from individual failures during forward and/or backward pass as long as at least
  136. one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
  137. Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
  138. experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
  139. """
  140. @classmethod
  141. def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
  142. timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
  143. detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
  144. *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
  145. assert not torch.is_grad_enabled()
  146. num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
  147. flat_inputs_cpu = []
  148. for tensor in flat_inputs:
  149. if detect_anomalies and not tensor.isfinite().all():
  150. raise ValueError("One of inputs has nan/inf values")
  151. flat_inputs_cpu.append(tensor.cpu())
  152. flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
  153. assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
  154. # dispatch tasks to all remote experts collect responses
  155. pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
  156. for i in range(num_samples):
  157. for j, expert in enumerate(experts_per_sample[i]):
  158. input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
  159. flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
  160. stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
  161. new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
  162. pending_tasks[new_task] = (i, j)
  163. responded_inds, alive_flat_outputs = cls._collect_responses(
  164. pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
  165. if len(responded_inds) < k_min:
  166. raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
  167. if not isinstance(info['outputs_schema'], tuple):
  168. outputs_schema = (info['outputs_schema'],)
  169. else:
  170. outputs_schema = info['outputs_schema']
  171. outputs = nested_map(
  172. lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
  173. outputs_schema)
  174. # assemble responses
  175. if len(responded_inds) > 0 or allow_zero_outputs:
  176. batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
  177. list(zip(*responded_inds)) or ([], []))
  178. alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
  179. # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
  180. for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
  181. output[batch_inds, expert_inds] = response_stacked.to(output.device)
  182. else:
  183. raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')
  184. mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
  185. mask[batch_inds, expert_inds] = True
  186. # save individual outputs for backward pass
  187. ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
  188. ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
  189. detect_anomalies)
  190. return (mask,) + outputs
  191. @classmethod
  192. @once_differentiable
  193. def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
  194. assert not torch.is_grad_enabled()
  195. (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
  196. detect_anomalies) = ctx._saved_non_tensors
  197. alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
  198. dummy_grad_mask, *flat_grad_outputs = raw_grads
  199. flat_grad_outputs_cpu = []
  200. for tensor in flat_grad_outputs:
  201. if detect_anomalies and not tensor.isfinite().all():
  202. raise ValueError("One of gradients has nan/inf values")
  203. flat_grad_outputs_cpu.append(tensor.cpu())
  204. num_samples, max_experts = dummy_grad_mask.shape
  205. inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
  206. grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
  207. backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
  208. # dispatch tasks to all remote experts, collect responses
  209. pending_tasks = {}
  210. for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
  211. inputs_per_expert, grad_outputs_per_expert):
  212. expert = expert_per_sample[i.item()][j.item()]
  213. stub = _get_expert_stub(expert.endpoint)
  214. inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
  215. tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
  216. for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
  217. new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
  218. pending_tasks[new_task] = (i, j)
  219. survivor_inds, survivor_grad_inputs = cls._collect_responses(
  220. pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
  221. if len(survivor_inds) < backward_k_min:
  222. raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
  223. # assemble responses
  224. batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, dtype=torch.long),
  225. list(zip(*survivor_inds)) or ([], []))
  226. survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
  227. # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
  228. grad_inputs = nested_map(
  229. lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
  230. list(nested_flatten(info['forward_schema'])))
  231. for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
  232. grad_input_per_expert = torch.zeros( # gradient tensor with individual contributions from each expert
  233. (num_samples, max_experts, *grad_input.shape[1:]),
  234. device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
  235. grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
  236. grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
  237. return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
  238. @staticmethod
  239. def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
  240. timeout_total: Optional[float], timeout_after_k_min: Optional[float], detect_anomalies: bool
  241. ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
  242. """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
  243. timeout_total = float('inf') if timeout_total is None else timeout_total
  244. timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
  245. num_successful_tasks = [0 for _ in range(num_samples)]
  246. pending_samples = num_samples # samples for which we have less than k_min results
  247. finished_indices, finished_outputs = [], []
  248. t_finish = time.perf_counter() + timeout_total
  249. pending_tasks = set(task_to_indices.keys())
  250. finished_tasks = Queue()
  251. try:
  252. # the algorithm below is essentially futures.as_completed, but for grpc.Future
  253. for task in pending_tasks:
  254. task.add_done_callback(finished_tasks.put)
  255. for _ in range(len(task_to_indices)):
  256. timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float('inf') else None
  257. task = finished_tasks.get(timeout=timeout)
  258. pending_tasks.discard(task)
  259. task_output = _process_dispatched_task(task, detect_anomalies)
  260. if task_output is not None:
  261. finished_indices.append(task_to_indices[task])
  262. finished_outputs.append(task_output)
  263. # count how many successes we have for each input sample
  264. sample_index = task_to_indices[task][0]
  265. num_successful_tasks[sample_index] += 1
  266. if num_successful_tasks[sample_index] == k_min:
  267. pending_samples -= 1
  268. if pending_samples <= 0: # all tasks finished, await stragglers for at most timeout_after_k_min
  269. t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
  270. except Empty:
  271. pass # we reached t_finish, this is normal behavior
  272. finally:
  273. for task in pending_tasks:
  274. task.cancel()
  275. return finished_indices, finished_outputs
  276. def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
  277. if task.exception() or task.cancelled():
  278. logger.warning(f"Task {task} failed: {type(task.exception())}")
  279. return None
  280. deserialized_outputs = []
  281. for tensor in task.result().tensors:
  282. deserialized_tensor = deserialize_torch_tensor(tensor)
  283. if detect_anomalies and not deserialized_tensor.isfinite().all():
  284. logger.error(f"Task {task} failed: output tensor contains nan/inf values")
  285. return None
  286. deserialized_outputs.append(deserialized_tensor)
  287. return tuple(deserialized_outputs)