123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import multiprocessing as mp
- import multiprocessing.pool
- from functools import partial
- from typing import Tuple, List, Dict, Any
- import numpy as np
- import torch
- import torch.nn as nn
- from .remote_expert import RemoteExpert
- from ..utils import nested_map, check_numpy, run_and_await_k
class GatingFunction(nn.Module):
    """
    Gating function that routes every input to several remote experts laid out on a grid.

    A single linear layer jointly predicts one logit per coordinate of every grid dimension.
    For each sample the best experts are located with beam search over the grid, called in
    background threads, and their outputs are mixed with softmax weights derived from the
    same logits.

    :param in_features: size of the last axis of the tensor fed to the gating projection
    :param grid_size: number of coordinates along each grid dimension
    :param network: expert lookup backend; this class reads network.UID_DELIMETER and calls
        network.first_k_active(...) and network.get_experts(...)
    :param num_workers: thread pool size; k_best * 2 is used when it is None (or 0)
    :param k_best: beam size, i.e. the maximum number of experts chosen per sample
    :param k_min: forwarded to run_and_await_k as k
    :param timeout_after_k_min: forwarded to run_and_await_k as timeout_after_k
    :param uid_prefix: string that every searched expert uid starts with
    :param expert_padding: placeholder that fills beams with fewer than k_best entries
    """

    def __init__(self, *, in_features, grid_size: Tuple[int, ...], network, num_workers=None,
                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
        super().__init__()
        self.network, self.grid_size = network, grid_size
        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions

    def forward(self, input: torch.Tensor, *args, **kwargs) -> Any:
        """
        Choose k best experts with beam search, then call chosen experts and average their outputs.

        :param input: a 2d tensor [batch_size, in_features]; used for gating and also passed to
            the chosen experts as their first argument
        :param args: extra tensors for the experts, sliced along axis 0 together with input
        :param kwargs: extra named tensors for the experts, sliced along axis 0 together with input
        :return: predictions of the experts that delivered on time, averaged with softmax weights
            and concatenated over the batch; same nested structure as a single expert's output
        :raises ValueError: if not a single expert returned an output for some sample
        """
        assert len(input.shape) == 2

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
        batch_experts = self.beam_search(grid_scores, self.k_best)
        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

        # 2.1 call chosen experts (run them in background to save time)
        # every sample keeps its batch axis (slice i: i + 1), so outputs can be concatenated later
        batch_outputs_async = [
            self.thread_pool.apply_async(
                self._run_experts,
                args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
                kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
            for i, chosen_experts in enumerate(batch_experts)
        ]

        # 2.2 compute *differentiable* logits for each expert
        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert

        batch_outputs = []
        for sample_index, (output_async, expert_logits) in enumerate(
                zip(batch_outputs_async, batch_expert_logits)):
            expert_outputs = output_async.get()  # Dict[RemoteExpert, Any]
            if not expert_outputs:
                # zip(*...) below cannot be unpacked for an empty dict; fail with a message
                # that says what actually happened instead of "not enough values to unpack"
                raise ValueError(f"no expert returned an output for sample {sample_index} of the batch")
            flat_experts, flat_outputs = zip(*expert_outputs.items())

            # 3.1. normalize logits over only those experts that DID return output
            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
            flat_weights = torch.softmax(flat_logits, dim=-1)

            # 3.2. average each output across experts
            average_outputs = nested_map(
                lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
            batch_outputs.append(average_outputs)

        # 4. concatenate mixture outputs from individual experts
        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)

    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List["RemoteExpert"]]:
        """
        Find and return k best experts in the grid using (exact) beam search of the product space

        :param grid_scores: scores predicted for each dimension in the grid,
        :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
        :param k_best: how many of the top experts participate in the computation
        :param kwargs: extra keyword parameters passed to self.network.first_k_active
        :returns: a list of *batch_size* lists that contain chosen experts for one sample
            each inner list contains RemoteExpert instances for *up to* k_best experts
        """
        assert len(grid_scores) == len(self.grid_size)
        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
        batch_size = len(grid_scores[0])
        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
        scores = np.zeros([batch_size, 1], dtype=np.float64)

        delimeters = np.array(self.network.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat

        for dim_index, dim_scores in enumerate(grid_scores):
            dim_scores = check_numpy(dim_scores)
            assert dim_scores.shape[-1] == self.grid_size[dim_index]

            # create all possible successors from current beam
            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
            new_candidates = new_candidates.reshape([batch_size, -1])

            new_scores = scores[:, :, None] + dim_scores[:, None, :]
            new_scores = new_scores.reshape([batch_size, -1])

            # select k best candidates according to scores but only those that are still active
            new_order = np.argsort(- new_scores, axis=-1)
            top_alive_lookups = [
                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
                for cands, order in zip(new_candidates, new_order)]

            # build candidate -> score tables while the lookups run in the thread pool
            batch_cand_to_score = [
                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]

            top_alive_prefixes = [result.get() for result in top_alive_lookups]
            top_alive_scores = [list(map(cand_to_score.get, top_cands))
                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]

            # pad up to beam size
            # NOTE(review): padded entries take part in the string concatenation above on the next
            # grid dimension; with the default expert_padding=None that would add None to a str.
            # Confirm that first_k_active always yields k_best prefixes, or pass a string padding.
            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
                             for row in top_alive_prefixes], dtype='object')
            scores = np.array([row + [-float('inf')] * (k_best - len(row))
                               for row in top_alive_scores], dtype='float32')

        # resolve every distinct uid once, then map each beam row back to expert objects
        unique_experts = self.network.get_experts(list(set(
            uid for row in beam for uid in row if uid != self.expert_padding)))
        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}

        return [
            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
            for row in beam]

    def _run_experts(self, experts: List["RemoteExpert"], *args, **kwargs) -> Dict["RemoteExpert", torch.Tensor]:
        """
        Call every expert with the same arguments and collect the outputs that are not exceptions.

        :param experts: experts to call; each one is invoked as expert(*args, **kwargs)
        :returns: a dict {expert: output} without the experts whose result is a BaseException instance
        """
        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
        return {expert: output for expert, output in zip(experts, outputs)
                if not isinstance(output, BaseException)}

    def _score_experts(self, grid_scores: List[torch.Tensor],
                       experts: List[List["RemoteExpert"]]) -> List[Dict["RemoteExpert", torch.Tensor]]:
        """
        Compute the logit of every chosen expert as the sum of its grid scores over all dimensions.

        :param grid_scores: per-dimension score tensors, indexed here as [batch index, grid coordinate]
        :param experts: for every sample in the batch, the experts chosen for that sample
        :returns: for every sample, a dict {expert: scalar score tensor} gathered from grid_scores
        """
        flat_experts = [expert for row in experts for expert in row]
        # batch index of every entry in flat_experts
        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) for _ in row])

        # recover integer grid coordinates from the part of the uid that follows prefix + delimiter
        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
            expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
            grid_indices[i] = expert_indices

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        output_dicts = [dict() for _ in range(len(experts))]
        for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
                                          flat_experts, flat_scores):
            output_dicts[batch_i][expert] = score
        return output_dicts
|