@@ -0,0 +1,153 @@
+import multiprocessing as mp
+import multiprocessing.pool
+from functools import partial
+from typing import Tuple, List, Dict, Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .remote_expert import RemoteExpert
+from ..utils import nested_map, check_numpy, run_and_await_k
+
+
+class GatingFunction(nn.Module):
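+    """
+    A module that selects up to k best experts on a multi-dimensional grid via beam search,
+    calls them asynchronously and averages their outputs with differentiable softmax weights.
+    """
+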
+    def __init__(self, *, in_features, grid_size: Tuple[int, ...], network, num_workers=None,
+                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
+        super().__init__()
+        self.network, self.grid_size = network, grid_size
+        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
+        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
+
+        # the thread pool serves both asynchronous expert calls and parallel beam-search lookups
+        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
+        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
+
+    def forward(self, input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        """
+        Choose k best experts with beam search, then call the chosen experts and average their outputs.
+        :param input: a batch-first tensor of activations, shape [batch_size, in_features]; any extra
+            tensors in args and kwargs are sliced along the 0-th (batch) axis and sent to the experts
+        :return: averaged predictions of all experts that delivered on time
+        """
+        assert len(input.shape) == 2
+
+        # 1. compute scores and find the most appropriate experts with beam search
+        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
+        batch_experts = self.beam_search(grid_scores, self.k_best)
+        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
+
+        # 2.1 call chosen experts (run them in background to save time)
+        batch_outputs_async = [
+            self.thread_pool.apply_async(
+                self._run_experts,
+                args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
+                kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
+            for i, chosen_experts in enumerate(batch_experts)
+        ]
+
+        # 2.2 compute *differentiable* logits for each expert
+        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
+        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert
+
+        batch_outputs = []
+        for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits):
+            expert_outputs: Dict[RemoteExpert, Any] = output_async.get()
+            flat_experts, flat_outputs = zip(*expert_outputs.items())
+
+            # 3.1. normalize logits over only those experts that DID return output
+            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
+            flat_weights = torch.softmax(flat_logits, dim=-1)
+
+            # 3.2. average each output across experts
+            average_outputs = nested_map(
+                lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
+
+            batch_outputs.append(average_outputs)
+
+        # 4. concatenate per-example mixture outputs back into one batch
+        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)
+
+    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
+        """
+        Find and return the k best experts in the grid using (exact) beam search over the product space.
+        :param grid_scores: scores predicted for each dimension in the grid
+        :type grid_scores: a sequence of tensors of shape [batch_size, self.grid_size[i]]
+        :param k_best: how many of the top experts participate in the computation
+        :param kwargs: extra keyword parameters passed to self.network.first_k_active
+        :returns: a list of *batch_size* lists, each of which contains the experts chosen for one sample;
+            each inner list contains RemoteExpert instances for *up to* k_best experts
+        """
+        assert len(grid_scores) == len(self.grid_size)
+        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
+        batch_size = len(grid_scores[0])
+        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
+        scores = np.zeros([batch_size, 1], dtype=np.float64)
+
+        delimiters = np.array(self.network.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
+
+        for dim_index, dim_scores in enumerate(grid_scores):
+            dim_scores = check_numpy(dim_scores)
+            assert dim_scores.shape[-1] == self.grid_size[dim_index]
+
+            # create all possible successors from the current beam
+            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
+            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
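+            # ^-- object-dtype broadcasting concatenates uid strings:
+            #     [batch, beam, 1] + [1, 1, 1] + [1, 1, grid_size[dim]] -> [batch, beam, grid_size[dim]]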
+            new_candidates = new_candidates.reshape([batch_size, -1])
+
+            new_scores = scores[:, :, None] + dim_scores[:, None, :]
+            new_scores = new_scores.reshape([batch_size, -1])
+
+            # select k best candidates according to scores but only those that are still active
+            new_order = np.argsort(- new_scores, axis=-1)
+            top_alive_lookups = [
+                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
+                for cands, order in zip(new_candidates, new_order)]
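+            # ^-- query the network asynchronously for each sample, keeping only prefixes that are alive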
+
+            batch_cand_to_score = [
+                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
+
+            top_alive_prefixes = [result.get() for result in top_alive_lookups]
+            top_alive_scores = [list(map(cand_to_score.get, top_cands))
+                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
+
+            # pad up to beam size
+            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
+                             for row in top_alive_prefixes], dtype='object')
+            scores = np.array([row + [-float('inf')] * (k_best - len(row))
+                               for row in top_alive_scores], dtype='float32')
+
+        unique_experts = self.network.get_experts(list(set(
+            uid for row in beam for uid in row if uid != self.expert_padding)))
+        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
+
+        return [
+            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
+            for row in beam]
+
+    def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
+        # wait for at least k_min experts, then wait at most timeout_after_k_min for the remainder
+        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
+                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
+        # drop experts that failed or timed out; their outputs arrive as exception instances
+        return {expert: output for expert, output in zip(experts, outputs)
+                if not isinstance(output, BaseException)}
+
+    def _score_experts(self, grid_scores: List[torch.Tensor],
+                       experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
+        flat_experts = [expert for row in experts for expert in row]
+        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts)
+                                           for _ in range(len(row))])
+
+        # recover each expert's grid coordinates from its uid
+        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
+        for i, expert in enumerate(flat_experts):
+            expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
+            expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
+            grid_indices[i] = expert_indices
+
+        # each expert's differentiable logit is the sum of its per-dimension scores
+        scores_per_dim = [
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
+        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
+
+        output_dicts = [dict() for _ in range(len(experts))]
+        for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
+                                          flat_experts, flat_scores):
+            output_dicts[batch_i][expert] = score
+
+        return output_dicts
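+
+
+# A minimal usage sketch (hypothetical: `network` is an object exposing UID_DELIMETER,
+# first_k_active(uids, k) and get_experts(uids), with experts laid out on a 2d grid under prefix 'ffn'):
+#
+#   gating = GatingFunction(in_features=1024, grid_size=(32, 32), network=network, k_best=4)
+#   averaged_output = gating(torch.randn(16, 1024))  # mixture-of-experts prediction for a batch of 16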