moe.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. from __future__ import annotations
  2. import time
  3. from queue import Empty, Queue
  4. from typing import Any, Dict, List, Optional, Tuple
  5. import grpc
  6. import torch
  7. import torch.nn as nn
  8. from torch.autograd.function import once_differentiable
  9. import hivemind
  10. from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
  11. from hivemind.moe.client.beam_search import MoEBeamSearcher
  12. from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
  13. from hivemind.moe.server.expert_uid import UID_DELIMITER
  14. from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
  15. from hivemind.utils import nested_compare, nested_flatten, nested_map, nested_pack
  16. from hivemind.utils.logging import get_logger
  17. logger = get_logger(__name__)
  18. class RemoteMixtureOfExperts(nn.Module):
  19. """
  20. A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts.
  21. Natively supports pytorch autograd.
  22. :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
  23. forward pass are guaranteed to perform backward pass. In the latter case, gradient will be averaged without
  24. the missing experts
  25. :param in_features: common input size for experts and gating function
  26. :param grid_size: dimensions that form expert uid (see below)
  27. :param uid_prefix: common prefix for all expert uids (must end with '.')
  28. :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
  29. :param dht: a DHT instance used to search for best experts
  30. :param k_best: average this many highest-scoring experts to compute activations
  31. :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
  32. :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
  33. Any expert that didn't manage to return output after that delay is considered unavailable
  34. :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
  35. :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
  36. """
  37. def __init__(
  38. self,
  39. *,
  40. in_features,
  41. grid_size: Tuple[int, ...],
  42. dht: hivemind.DHT,
  43. uid_prefix: str,
  44. k_best: int,
  45. k_min: int = 1,
  46. forward_timeout: Optional[float] = None,
  47. timeout_after_k_min: Optional[float] = None,
  48. backward_k_min: int = 1,
  49. backward_timeout: Optional[float] = None,
  50. detect_anomalies: bool = False,
  51. allow_zero_outputs: bool = False,
  52. **dht_kwargs,
  53. ):
  54. super().__init__()
  55. self.dht = dht
  56. self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
  57. self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
  58. self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
  59. self.timeout_after_k_min = timeout_after_k_min
  60. self.detect_anomalies = detect_anomalies
  61. self.allow_zero_outputs = allow_zero_outputs
  62. # jointly predict logits for all grid dimensions
  63. self.proj = nn.Linear(in_features, self.beam_search.total_grid_size)
  64. self._expert_info = None # expert['info'] from one of experts in the grid
  65. def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
  66. """
  67. Choose k best experts with beam search, then call chosen experts and average their outputs.
  68. Input tensor is averaged over all dimensions except for first and last
  69. (we assume that extra dimensions represent sequence length or image height/width)
  70. :param input: a tensor of values that are used to estimate gating function, batch-first.
  71. :param args: extra positional parameters that will be passed to each expert after input, batch-first
  72. :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
  73. :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
  74. """
  75. # 1. compute scores and find most appropriate experts with beam search
  76. grid_scores = [torch.randn(input.size(0), dim) for dim in self.beam_search.grid_size]
  77. chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
  78. [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
  79. )
  80. if self._expert_info is None:
  81. try:
  82. self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
  83. except StopIteration:
  84. raise RuntimeError(
  85. "No responding experts found during beam search. Check that UID prefixes and "
  86. "the grid size are consistent with running Server instances."
  87. )
  88. except grpc.RpcError as e:
  89. logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
  90. flat_inputs = nested_flatten(((input, *args), kwargs))
  91. expert_mask, *expert_outputs = _RemoteCallMany.apply(
  92. DUMMY,
  93. chosen_experts,
  94. self.k_min,
  95. self.backward_k_min,
  96. self.timeout_after_k_min,
  97. self.forward_timeout,
  98. self.backward_timeout,
  99. self.detect_anomalies,
  100. self.allow_zero_outputs,
  101. self.info,
  102. *flat_inputs,
  103. )
  104. # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
  105. expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
  106. masked_logits = torch.full((1,), float("-inf"), device=expert_logits.device, dtype=expert_logits.dtype)
  107. expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
  108. expert_weights = torch.softmax(expert_logits, dim=1)
  109. averaged_outputs_flat = [
  110. (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
  111. for tensor in expert_outputs
  112. ] # ^-- multiply by softmax weights along first 2 axes
  113. return nested_pack(averaged_outputs_flat, self.info["outputs_schema"])
  114. def compute_expert_scores(
  115. self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]
  116. ) -> torch.Tensor:
  117. """
  118. Compute scores for each expert by adding up grid scores, autograd-friendly
  119. :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
  120. :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
  121. :returns: a tensor of scores, float32[batch_size, k]
  122. :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
  123. """
  124. expert_counts = list(map(len, batch_experts))
  125. batch_size = len(batch_experts)
  126. max_num_experts = max(expert_counts)
  127. total_num_experts = sum(expert_counts)
  128. device = grid_scores[0].device
  129. expert_index_in_batch = torch.arange(total_num_experts, device=device)
  130. expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
  131. flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
  132. flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
  133. flat_experts = [expert for row in batch_experts for expert in row]
  134. grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
  135. for i, expert in enumerate(flat_experts):
  136. expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
  137. expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
  138. grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
  139. scores_per_dim = [
  140. dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
  141. for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)
  142. ]
  143. flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
  144. scores = torch.full((batch_size, max_num_experts), fill_value=-float("inf"), device=device)
  145. scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
  146. return scores
  147. @property
  148. def info(self):
  149. if self._expert_info is None:
  150. # grab some expert to set ensemble output shape
  151. proj_device = self.proj.weight.device
  152. dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
  153. dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
  154. dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
  155. self._expert_info = dummy_experts[0].info
  156. return self._expert_info
  157. class _RemoteCallMany(torch.autograd.Function):
  158. """
  159. Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses
  160. This function that can recover from individual failures during forward and/or backward pass as long as at least
  161. one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
  162. Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
  163. experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
  164. """
  165. @classmethod
  166. def forward(
  167. cls,
  168. ctx,
  169. dummy,
  170. experts_per_sample: List[List[RemoteExpert]],
  171. k_min: int,
  172. backward_k_min: int,
  173. timeout_after_k_min: float,
  174. forward_timeout: Optional[float],
  175. backward_timeout: Optional[float],
  176. detect_anomalies: bool,
  177. allow_zero_outputs: bool,
  178. info: Dict[str, Any],
  179. *flat_inputs: torch.Tensor,
  180. ) -> Tuple[torch.Tensor]:
  181. assert not torch.is_grad_enabled()
  182. num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
  183. flat_inputs_cpu = []
  184. for tensor in flat_inputs:
  185. if detect_anomalies and not tensor.isfinite().all():
  186. raise ValueError("One of inputs has nan/inf values")
  187. flat_inputs_cpu.append(tensor.cpu())
  188. flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
  189. assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
  190. # dispatch tasks to all remote experts collect responses
  191. pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
  192. for i in range(num_samples):
  193. for j, expert in enumerate(experts_per_sample[i]):
  194. input_tensors = [
  195. serialize_torch_tensor(tensor, proto.compression)
  196. for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
  197. ]
  198. stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
  199. new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
  200. pending_tasks[new_task] = (i, j)
  201. responded_inds, alive_flat_outputs = cls._collect_responses(
  202. pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
  203. )
  204. if len(responded_inds) < k_min:
  205. raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
  206. if not isinstance(info["outputs_schema"], tuple):
  207. outputs_schema = (info["outputs_schema"],)
  208. else:
  209. outputs_schema = info["outputs_schema"]
  210. outputs = nested_map(
  211. lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
  212. outputs_schema,
  213. )
  214. # assemble responses
  215. if len(responded_inds) > 0 or allow_zero_outputs:
  216. batch_inds, expert_inds = map(
  217. lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
  218. list(zip(*responded_inds)) or ([], []),
  219. )
  220. alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
  221. # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
  222. for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
  223. output[batch_inds, expert_inds] = response_stacked.to(output.device)
  224. else:
  225. raise RuntimeError("Forward pass: 0 experts responded within timeout and allow_zero_outputs is False")
  226. mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
  227. mask[batch_inds, expert_inds] = True
  228. # save individual outputs for backward pass
  229. ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
  230. ctx._saved_non_tensors = (
  231. info,
  232. backward_k_min,
  233. backward_timeout,
  234. timeout_after_k_min,
  235. experts_per_sample,
  236. detect_anomalies,
  237. )
  238. return (mask,) + outputs
  239. @classmethod
  240. @once_differentiable
  241. def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
  242. assert not torch.is_grad_enabled()
  243. (
  244. info,
  245. backward_k_min,
  246. backward_timeout,
  247. timeout_after_k_min,
  248. expert_per_sample,
  249. detect_anomalies,
  250. ) = ctx._saved_non_tensors
  251. alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
  252. dummy_grad_mask, *flat_grad_outputs = raw_grads
  253. flat_grad_outputs_cpu = []
  254. for tensor in flat_grad_outputs:
  255. if detect_anomalies and not tensor.isfinite().all():
  256. raise ValueError("One of gradients has nan/inf values")
  257. flat_grad_outputs_cpu.append(tensor.cpu())
  258. num_samples, max_experts = dummy_grad_mask.shape
  259. inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
  260. grad_outputs_per_expert = zip(
  261. *(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
  262. )
  263. backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
  264. # dispatch tasks to all remote experts, collect responses
  265. pending_tasks = {}
  266. for i, j, inputs_ij, grad_outputs_ij in zip(
  267. alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
  268. ):
  269. expert = expert_per_sample[i.item()][j.item()]
  270. stub = _get_expert_stub(expert.endpoint)
  271. inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
  272. tensors_serialized = [
  273. serialize_torch_tensor(tensor, proto.compression)
  274. for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
  275. ]
  276. new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
  277. pending_tasks[new_task] = (i, j)
  278. survivor_inds, survivor_grad_inputs = cls._collect_responses(
  279. pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
  280. )
  281. if len(survivor_inds) < backward_k_min:
  282. raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
  283. # assemble responses
  284. batch_inds, expert_inds = map(
  285. lambda x: torch.as_tensor(x, dtype=torch.long), list(zip(*survivor_inds)) or ([], [])
  286. )
  287. survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
  288. # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
  289. grad_inputs = nested_map(
  290. lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
  291. list(nested_flatten(info["forward_schema"])),
  292. )
  293. for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
  294. grad_input_per_expert = torch.zeros( # gradient tensor with individual contributions from each expert
  295. (num_samples, max_experts, *grad_input.shape[1:]),
  296. device=survivor_grad_stacked.device,
  297. dtype=survivor_grad_stacked.dtype,
  298. )
  299. grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
  300. grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
  301. return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
  302. @staticmethod
  303. def _collect_responses(
  304. task_to_indices: Dict[grpc.Future, Tuple[int, int]],
  305. num_samples: int,
  306. k_min: int,
  307. timeout_total: Optional[float],
  308. timeout_after_k_min: Optional[float],
  309. detect_anomalies: bool,
  310. ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
  311. """await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers"""
  312. timeout_total = float("inf") if timeout_total is None else timeout_total
  313. timeout_after_k_min = float("inf") if timeout_after_k_min is None else timeout_after_k_min
  314. num_successful_tasks = [0 for _ in range(num_samples)]
  315. pending_samples = num_samples # samples for which we have less than k_min results
  316. finished_indices, finished_outputs = [], []
  317. t_finish = time.perf_counter() + timeout_total
  318. pending_tasks = set(task_to_indices.keys())
  319. finished_tasks = Queue()
  320. try:
  321. # the algorithm below is essentially futures.as_completed, but for grpc.Future
  322. for task in pending_tasks:
  323. task.add_done_callback(finished_tasks.put)
  324. for _ in range(len(task_to_indices)):
  325. timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float("inf") else None
  326. task = finished_tasks.get(timeout=timeout)
  327. pending_tasks.discard(task)
  328. task_output = _process_dispatched_task(task, detect_anomalies)
  329. if task_output is not None:
  330. finished_indices.append(task_to_indices[task])
  331. finished_outputs.append(task_output)
  332. # count how many successes we have for each input sample
  333. sample_index = task_to_indices[task][0]
  334. num_successful_tasks[sample_index] += 1
  335. if num_successful_tasks[sample_index] == k_min:
  336. pending_samples -= 1
  337. if (
  338. pending_samples <= 0
  339. ): # all tasks finished, await stragglers for at most timeout_after_k_min
  340. t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
  341. except Empty:
  342. pass # we reached t_finish, this is normal behavior
  343. finally:
  344. for task in pending_tasks:
  345. task.cancel()
  346. return finished_indices, finished_outputs
  347. def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
  348. if task.exception() or task.cancelled():
  349. logger.warning(f"Task {task} failed: {type(task.exception())}")
  350. return None
  351. deserialized_outputs = []
  352. for tensor in task.result().tensors:
  353. deserialized_tensor = deserialize_torch_tensor(tensor)
  354. if detect_anomalies and not deserialized_tensor.isfinite().all():
  355. logger.error(f"Task {task} failed: output tensor contains nan/inf values")
  356. return None
  357. deserialized_outputs.append(deserialized_tensor)
  358. return tuple(deserialized_outputs)