gating_function.py

import multiprocessing as mp
import multiprocessing.pool
from functools import partial
from typing import Tuple, List, Dict, Any

import numpy as np
import torch
import torch.nn as nn

from .remote_expert import RemoteExpert
from ..utils import nested_map, check_numpy, run_and_await_k
class GatingFunction(nn.Module):
    def __init__(self, *, in_features, grid_size: Tuple[int, ...], network, num_workers=None,
                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
        """
        A torch module that selects experts across the network and averages their predictions.

        :param in_features: common input size for the experts and the gating function
        :param grid_size: tesseract dimensions that form an expert uid (see below)
        :param uid_prefix: common prefix for all expert uids;
            an expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
        :param network: TesseractNetwork where the experts reside
        :param num_workers: number of threads for parallel network operations
        :param k_best: queries this many experts with the highest scores
        :param k_min: makes sure at least this many experts return output
        :param timeout_after_k_min: waits this many seconds after k_min experts have returned results;
            any expert that didn't manage to return output within that delay is considered unavailable
        :param expert_padding: internal value used to denote an "absent expert". Must not coincide with any expert uid.
        """
        super().__init__()
        self.network, self.grid_size = network, grid_size
        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min

        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
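    # A minimal construction sketch (illustrative only, not part of the original module;
    # `my_network` stands for an assumed, already-running TesseractNetwork instance):
    #
    #     gating = GatingFunction(in_features=1024, grid_size=(32, 32), network=my_network,
    #                             k_best=4, uid_prefix='ffn')
    #
    # self.proj then emits 32 + 32 = 64 logits per input: one score for each coordinate
    # along each grid dimension, split back apart in forward() via split_with_sizes.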
    def forward(self, input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Choose the k best experts with beam search, then call the chosen experts and average their outputs.

        :param input: a batch-first tensor of shape [batch_size, in_features]; any extra tensors in
            *args / **kwargs are sliced along the batch axis and forwarded to the chosen experts
        :returns: averaged predictions of all experts that delivered output on time
        """
        assert len(input.shape) == 2

        # 1. compute scores and find the most appropriate experts with beam search
        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
        batch_experts = self.beam_search(grid_scores, self.k_best)
        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in the batch

        # 2.1 call the chosen experts (run them in the background to save time)
        batch_outputs_async = [
            self.thread_pool.apply_async(
                self._run_experts,
                args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
                kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
            for i, chosen_experts in enumerate(batch_experts)
        ]

        # 2.2 compute *differentiable* logits for each expert
        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert

        batch_outputs = []
        for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits):
            expert_outputs: Dict[RemoteExpert, Any] = output_async.get()
            flat_experts, flat_outputs = zip(*expert_outputs.items())

            # 3.1. normalize logits over only those experts that DID return output
            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
            flat_weights = torch.softmax(flat_logits, dim=-1)

            # 3.2. average each output across experts
            average_outputs = nested_map(
                lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
            batch_outputs.append(average_outputs)

        # 4. concatenate mixture outputs from individual experts back into one batch
        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)
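    # Usage sketch for forward() (illustrative only; `gating` is the instance from the
    # sketch above, and the input shape is an assumption):
    #
    #     x = torch.randn(16, 1024)    # [batch_size, in_features]
    #     averaged = gating(x)         # queries up to k_best experts per sample, averages replies
    #
    # gradients reach self.proj through flat_weights (the softmax over expert logits),
    # so the gating function itself is trainable even though the experts run remotely.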
    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
        """
        Find and return the k best experts in the grid using (exact) beam search over the product space.

        :param grid_scores: scores predicted for each dimension in the grid
        :type grid_scores: a sequence of tensors of shape [batch_size, self.grid_size[i]]
        :param k_best: how many of the top experts participate in the computation
        :param kwargs: extra keyword parameters passed to self.network.first_k_active
        :returns: a list of *batch_size* lists, one per sample; each inner list contains
            RemoteExpert instances for *up to* k_best experts
        """
        assert len(grid_scores) == len(self.grid_size)
        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
        batch_size = len(grid_scores[0])
        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
        scores = np.zeros([batch_size, 1], dtype=np.float64)

        delimiters = np.array(self.network.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat

        for dim_index, dim_scores in enumerate(grid_scores):
            dim_scores = check_numpy(dim_scores)
            assert dim_scores.shape[-1] == self.grid_size[dim_index]

            # create all possible successors of the current beam
            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
            new_candidates = new_candidates.reshape([batch_size, -1])

            new_scores = scores[:, :, None] + dim_scores[:, None, :]
            new_scores = new_scores.reshape([batch_size, -1])

            # select the k best candidates according to scores, but only those that are still active
            new_order = np.argsort(-new_scores, axis=-1)
            top_alive_lookups = [
                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
                for cands, order in zip(new_candidates, new_order)]

            batch_cand_to_score = [
                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]

            top_alive_prefixes = [result.get() for result in top_alive_lookups]
            top_alive_scores = [list(map(cand_to_score.get, top_cands))
                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]

            # pad up to the beam size
            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
                             for row in top_alive_prefixes], dtype='object')
            scores = np.array([row + [-float('inf')] * (k_best - len(row))
                               for row in top_alive_scores], dtype='float32')

        unique_experts = self.network.get_experts(list(set(
            uid for row in beam for uid in row if uid != self.expert_padding)))
        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}

        return [
            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
            for row in beam]
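    # Worked example of the candidate expansion (illustrative only; assumes
    # UID_DELIMETER == '.' and uid_prefix == 'ffn'), with grid_size = (2, 2):
    #
    #     start:  beam = ['ffn']
    #     dim 0:  candidates 'ffn.0', 'ffn.1'                       -> keep k_best alive prefixes
    #     dim 1:  candidates 'ffn.0.0', 'ffn.0.1', 'ffn.1.0', ...   -> keep k_best alive full uids
    #
    # a candidate's score is the sum of its per-dimension scores, and dead prefixes are
    # pruned at every step via network.first_k_active.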
    def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
        # call all experts concurrently, wait for at least k_min replies plus a grace period,
        # then drop experts whose calls failed or timed out (their slots hold exceptions)
        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
        return {expert: output for expert, output in zip(experts, outputs)
                if not isinstance(output, BaseException)}
    def _score_experts(self, grid_scores: List[torch.Tensor],
                       experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
        # for every chosen expert, recover its grid indices from its uid and sum the
        # corresponding per-dimension scores into a single differentiable logit
        flat_experts = [expert for row in experts for expert in row]
        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts)
                                           for _ in range(len(row))])

        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
            expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
            grid_indices[i] = expert_indices

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        output_dicts = [dict() for _ in range(len(experts))]
        for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
                                          flat_experts, flat_scores):
            output_dicts[batch_i][expert] = score
        return output_dicts
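# A self-contained illustration of the uid -> grid index parsing performed in _score_experts
# (assumes UID_DELIMETER == '.' and uid_prefix == 'ffn'; both belong to the surrounding
# network and are not guaranteed here):
#
#     uid = 'ffn.3.1.4'
#     body = uid[len('ffn') + len('.'):]            # -> '3.1.4'
#     indices = list(map(int, body.split('.')))     # -> [3, 1, 4]
#
# these per-dimension indices pick one logit from each grid_scores tensor; their sum is
# the expert's differentiable logit used for the softmax in forward().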