from __future__ import annotations

import time
from concurrent.futures import Future
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from hivemind.compression import serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.client.beam_search import MoEBeamSearcher
from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_server_stub
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import UID_DELIMITER
from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
from hivemind.utils import nested_flatten, nested_map, nested_pack
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class RemoteMixtureOfExperts(nn.Module):
    """
    A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts.
    Natively supports pytorch autograd.

    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all of those that ran
      the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient is averaged over
      the experts that did respond.

    :param in_features: common input size for experts and the gating function
    :param grid_size: dimensions that form expert uid (see below)
    :param uid_prefix: common prefix for all expert uids (must end with '.')
    :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
    :param dht: a DHT instance used to search for best experts
    :param k_best: average this many highest-scoring experts to compute activations
    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
      Any expert that didn't manage to return output after that delay is considered unavailable
    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
    :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
    """
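
    # A minimal usage sketch (illustrative assumptions, not part of this module: a DHT with running
    # hivemind expert servers announced under the "expert." prefix and a 4x4 expert grid):
    #
    #   dht = DHT(initial_peers=[...], start=True)
    #   moe = RemoteMixtureOfExperts(
    #       in_features=512, grid_size=(4, 4), dht=dht, uid_prefix="expert.", k_best=4
    #   )
    #   averaged_output = moe(torch.randn(32, 512))  # batch-first input, averaged over responding experts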

    def __init__(
        self,
        *,
        in_features,
        grid_size: Tuple[int, ...],
        dht: DHT,
        uid_prefix: str,
        k_best: int,
        k_min: int = 1,
        forward_timeout: Optional[float] = None,
        timeout_after_k_min: Optional[float] = None,
        backward_k_min: int = 1,
        backward_timeout: Optional[float] = None,
        detect_anomalies: bool = False,
        allow_zero_outputs: bool = False,
        **dht_kwargs,
    ):
        super().__init__()
        self.dht = dht
        self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.timeout_after_k_min = timeout_after_k_min
        self.detect_anomalies = detect_anomalies
        self.allow_zero_outputs = allow_zero_outputs

        # jointly predict logits for all grid dimensions
        self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
        self._expert_info = None  # expert['info'] from one of experts in the grid

    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Choose k best experts with beam search, then call chosen experts and average their outputs.
        Input tensor is averaged over all dimensions except for the first and last
        (we assume that extra dimensions represent sequence length or image height/width).

        :param input: a tensor of values that are used to estimate gating function, batch-first
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
        """
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
        )

        if self._expert_info is None:
            try:
                self._expert_info = next(expert.info for experts_i in chosen_experts for expert in experts_i)
            except StopIteration:
                raise RuntimeError(
                    "No responding experts found during beam search. Check that UID prefixes and "
                    "the grid size are consistent with running Server instances."
                )
            except P2PDaemonError as e:
                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

        # 2. call the chosen experts and collect their outputs
        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY,
            chosen_experts,
            self.k_min,
            self.backward_k_min,
            self.timeout_after_k_min,
            self.forward_timeout,
            self.backward_timeout,
            self.detect_anomalies,
            self.allow_zero_outputs,
            self.info,
            *nested_flatten(((input, *args), kwargs)),
        )
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        # 3. average outputs of the experts that responded, weighted by softmax over their gating scores
        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        masked_logits = torch.full((1,), float("-inf"), device=expert_logits.device, dtype=expert_logits.dtype)
        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
        expert_weights = torch.softmax(expert_logits, dim=1)
        averaged_outputs_flat = [
            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
            for tensor in expert_outputs
        ]  # ^-- multiply by softmax weights along first 2 axes
        return nested_pack(averaged_outputs_flat, self.info["outputs_schema"])
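
    # How the averaging above behaves (illustrative numbers, not from a real run): with k_best = 3
    # and one expert that failed to respond, expert_mask = [True, True, False] replaces the failed
    # expert's logit with -inf, so softmax gives it weight 0 and renormalizes the rest, e.g.
    # logits [1.0, 0.5, -inf] -> weights [0.62, 0.38, 0.00]; the zero-filled output of the failed
    # expert therefore does not contribute to the weighted sum.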

    def compute_expert_scores(
        self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]
    ) -> torch.Tensor:
        """
        Compute scores for each expert by adding up grid scores, autograd-friendly.

        :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
        :returns: a tensor of scores, float32[batch_size, k]
        :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
        """
        expert_counts = list(map(len, batch_experts))
        batch_size = len(batch_experts)
        max_num_experts = max(expert_counts)
        total_num_experts = sum(expert_counts)
        device = grid_scores[0].device

        expert_index_in_batch = torch.arange(total_num_experts, device=device)
        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
        flat_experts = [expert for row in batch_experts for expert in row]

        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)
        ]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        scores = torch.full((batch_size, max_num_experts), fill_value=-float("inf"), device=device)
        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
        return scores
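
    # Index bookkeeping above, on a toy batch (illustrative values, assuming uid_prefix="expert."):
    #   expert_counts = [2, 1, 3]  ->  expert_strides     = [0, 2, 3]
    #                                  flat_batch_indices = [0, 0, 1, 2, 2, 2]
    #                                  flat_local_indices = [0, 1, 0, 0, 1, 2]
    # an expert with uid "expert.3.7" chosen for sample i contributes grid_scores[0][i, 3] + grid_scores[1][i, 7]
    # to its slot in scores[i, :]; unused slots keep the -inf fill and vanish after the softmax in forward().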

    @property
    def info(self):
        if self._expert_info is None:
            # grab some expert to set ensemble output shape
            proj_device = self.proj.weight.device
            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
            self._expert_info = dummy_experts[0].info
        return self._expert_info


class _RemoteCallMany(torch.autograd.Function):
    """
    Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses.
    This function can recover from individual failures during the forward and/or backward pass, as long as at
    least one expert succeeds for each input. For a user-friendly version of this function, use the
    RemoteMixtureOfExperts module.

    Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
    experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
    """

    @classmethod
    def forward(
        cls,
        ctx,
        dummy,
        experts_per_sample: List[List[RemoteExpert]],
        k_min: int,
        backward_k_min: int,
        timeout_after_k_min: float,
        forward_timeout: Optional[float],
        backward_timeout: Optional[float],
        detect_anomalies: bool,
        allow_zero_outputs: bool,
        info: Dict[str, Any],
        *flat_inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
                stub = get_server_stub(expert.p2p, expert.peer_id)
                serialized_tensors = (
                    serialize_torch_tensor(tensor, proto.compression)
                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
                )
                new_task = RemoteExpertWorker.run_coroutine(
                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
                    return_future=True,
                )
                pending_tasks[new_task] = (i, j)

        responded_inds, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
        )
        if len(responded_inds) < k_min:
            raise TimeoutError(f"Forward pass: less than {k_min} experts responded within timeout")

        if not isinstance(info["outputs_schema"], tuple):
            outputs_schema = (info["outputs_schema"],)
        else:
            outputs_schema = info["outputs_schema"]
        outputs = nested_map(
            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
            outputs_schema,
        )

        # assemble responses
        if len(responded_inds) > 0 or allow_zero_outputs:
            batch_inds, expert_inds = map(
                lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
                list(zip(*responded_inds)) or ([], []),
            )

            alive_flat_outputs_stacked = (torch.cat(outputs_i) for outputs_i in zip(*alive_flat_outputs))
            # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

            for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
                output[batch_inds, expert_inds] = response_stacked.to(output.device)
        else:
            raise RuntimeError("Forward pass: 0 experts responded within timeout and allow_zero_outputs is False")

        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
        mask[batch_inds, expert_inds] = True

        # save individual outputs for backward pass
        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
        ctx._saved_non_tensors = (
            info,
            backward_k_min,
            backward_timeout,
            timeout_after_k_min,
            experts_per_sample,
            detect_anomalies,
        )
        return (mask,) + outputs
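
    # Shape convention for the values returned by forward() (a restatement of the code above):
    # `mask` is bool[num_samples, max_experts], True where expert (i, j) responded in time;
    # each tensor in `outputs` is [num_samples, max_experts, *expert_output_shape], zero-filled
    # for experts that failed, so the caller can weight and sum over the expert dimension.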

    @classmethod
    @once_differentiable
    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        assert not torch.is_grad_enabled()
        (
            info,
            backward_k_min,
            backward_timeout,
            timeout_after_k_min,
            expert_per_sample,
            detect_anomalies,
        ) = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

        dummy_grad_mask, *flat_grad_outputs = raw_grads
        flat_grad_outputs_cpu = []
        for tensor in flat_grad_outputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
            flat_grad_outputs_cpu.append(tensor.cpu())

        num_samples, max_experts = dummy_grad_mask.shape

        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
        grad_outputs_per_expert = zip(
            *(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
        )
        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks = {}
        for i, j, inputs_ij, grad_outputs_ij in zip(
            alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
        ):
            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
            stub = get_server_stub(expert.p2p, expert.peer_id)
            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
            serialized_tensors = (
                serialize_torch_tensor(tensor, proto.compression)
                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
            )
            new_task = RemoteExpertWorker.run_coroutine(
                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
            )
            pending_tasks[new_task] = (i, j)

        survivor_inds, survivor_grad_inputs = cls._collect_responses(
            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
        )
        if len(survivor_inds) < backward_k_min:
            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout")

        # assemble responses
        batch_inds, expert_inds = map(
            lambda x: torch.as_tensor(x, dtype=torch.long), list(zip(*survivor_inds)) or ([], [])
        )
        survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
        # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

        grad_inputs = nested_map(
            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
            list(nested_flatten(info["forward_schema"])),
        )
        for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                (num_samples, max_experts, *grad_input.shape[1:]),
                device=survivor_grad_stacked.device,
                dtype=survivor_grad_stacked.dtype,
            )
            grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
            grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

        return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
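
    # backward() returns one entry per forward() argument, as torch.autograd.Function requires:
    # DUMMY for `dummy`, None for the nine non-tensor arguments, then one gradient per flat input tensor;
    # experts that failed during backward leave their contribution at zero, i.e. they act as constants.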

    @staticmethod
    def _collect_responses(
        task_to_indices: Dict[Future, Tuple[int, int]],
        num_samples: int,
        k_min: int,
        timeout_total: Optional[float],
        timeout_after_k_min: Optional[float],
        detect_anomalies: bool,
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
        """await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers"""
        timeout_total = float("inf") if timeout_total is None else timeout_total
        timeout_after_k_min = float("inf") if timeout_after_k_min is None else timeout_after_k_min
        num_successful_tasks = [0 for _ in range(num_samples)]
        pending_samples = num_samples  # samples for which we have less than k_min results
        finished_indices, finished_outputs = [], []
        t_finish = time.perf_counter() + timeout_total
        pending_tasks = set(task_to_indices.keys())
        finished_tasks = Queue()

        try:
            # the algorithm below is essentially futures.as_completed, but for grpc.Future
            for task in pending_tasks:
                task.add_done_callback(finished_tasks.put)

            for _ in range(len(task_to_indices)):
                timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float("inf") else None
                task = finished_tasks.get(timeout=timeout)
                pending_tasks.discard(task)

                task_output = _process_dispatched_task(task, detect_anomalies)
                if task_output is not None:
                    finished_indices.append(task_to_indices[task])
                    finished_outputs.append(task_output)

                    # count how many successes we have for each input sample
                    sample_index = task_to_indices[task][0]
                    num_successful_tasks[sample_index] += 1
                    if num_successful_tasks[sample_index] == k_min:
                        pending_samples -= 1
                        if pending_samples <= 0:
                            # every sample has at least k_min results: await stragglers for at most timeout_after_k_min
                            t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
        except Empty:
            pass  # we reached t_finish, this is normal behavior
        finally:
            for task in pending_tasks:
                task.cancel()
        return finished_indices, finished_outputs
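
    # Timing of _collect_responses in words (numbers below are illustrative): with k_min=2,
    # timeout_total=10 and timeout_after_k_min=1, the loop waits until every sample has 2 successful
    # responses, then keeps accepting stragglers for at most 1 more second (never past the 10 s total
    # budget), and finally cancels whatever is still pending.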


def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor, ...]]:
    # check cancellation first: Future.exception() raises CancelledError on a cancelled future
    if task.cancelled():
        logger.warning(f"Task {task} was cancelled")
        return None
    if task.exception():
        logger.warning(f"Task {task} failed: {type(task.exception())}")
        return None

    outputs = task.result()
    for tensor in outputs:
        if detect_anomalies and not tensor.isfinite().all():
            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
            return None
    return outputs