@@ -1,7 +1,7 @@
 import multiprocessing as mp
 import multiprocessing.pool
 from functools import partial
-from typing import Tuple, List, Dict, Any, Optional
+from typing import Tuple, List, Dict, Optional
 
 import numpy as np
 import torch
@@ -10,7 +10,7 @@ from torch.autograd.function import once_differentiable
 
 from .expert import RemoteExpert, _RemoteModuleCall
 from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY
-from ..utils.autograd import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward
+from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
 class RemoteMixtureOfExperts(nn.Module):
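
The nested_* helpers imported above drive the input/output plumbing further down: nested_flatten yields the leaves of an arbitrarily nested structure, nested_pack restores a flat sequence into the shape of a schema, and nested_map applies a function to every leaf. A minimal self-contained sketch of that contract (a toy reimplementation for illustration, not the repo's code):

# sketch_nested_utils.py -- toy versions illustrating the contract, not the repo code
def nested_flatten_sketch(t):
    # yield leaves of lists/tuples/dicts in a deterministic order
    if isinstance(t, (list, tuple)):
        for sub in t:
            yield from nested_flatten_sketch(sub)
    elif isinstance(t, dict):
        for key in sorted(t):
            yield from nested_flatten_sketch(t[key])
    else:
        yield t

def nested_pack_sketch(flat, schema):
    # inverse of flatten: pour a flat sequence back into the schema's structure
    flat = iter(flat)
    def rebuild(s):
        if isinstance(s, (list, tuple)):
            return type(s)(rebuild(sub) for sub in s)
        if isinstance(s, dict):
            return {key: rebuild(s[key]) for key in sorted(s)}
        return next(flat)
    return rebuild(schema)

structure = ((1, 2), {'b': 3, 'a': 4})
leaves = list(nested_flatten_sketch(structure))              # [1, 2, 4, 3]
assert nested_pack_sketch(leaves, structure) == ((1, 2), {'a': 4, 'b': 3})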
@@ -35,58 +35,60 @@ class RemoteMixtureOfExperts(nn.Module):
     :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
     """
     def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
-                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
+                 k_best, k_min=1, timeout_after_k_min=1.0, forward_timeout=None, backward_k_min=1, backward_timeout=None,
+                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
-        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
+        self.k_best, self.k_min, self.backward_k_min, self.allow_broadcasting = k_best, k_min, backward_k_min, allow_broadcasting
+        self.timeout_after_k_min, self.forward_timeout, self.backward_timeout = timeout_after_k_min, forward_timeout, backward_timeout
 
         self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
         self.proj = nn.Linear(in_features, sum(grid_size)) # jointly predict logits for all grid dimensions
 
-    def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
+        # grab some expert to set ensemble output shape
+        dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)
+        self.outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['output_schema']
+
+    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
         Choose k best experts with beam search, then call chosen experts and average their outputs.
-
-        :param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
-        :returns: averaged predictions of all experts that delivered on time
+        :param input: a tensor of values that are used to estimate gating function, batch-first
+        :param args: extra positional parameters that will be passed to each expert after input, batch-first
+        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
+        :returns: averaged predictions of all experts that delivered a result on time, nested structure of batch-first
         """
-        assert len(input.shape) == 2
+        if self.allow_broadcasting and len(input.shape) != 2:
+            # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
+            flattened_dims = input.shape[:-1]
+            input_flat = input.view(-1, input.shape[-1])
+            args_flat = [tensor.view(-1, *tensor.shape[len(flattened_dims):]) for tensor in args]
+            kwargs_flat = {key: tensor.view(-1, *tensor.shape[len(flattened_dims):]) for key, tensor in kwargs.items()}
+            out_flat = self.forward(input_flat, *args_flat, **kwargs_flat)
+            return nested_map(lambda tensor: tensor.view(*flattened_dims, *tensor.shape[1:]), out_flat)
 
         # 1. compute scores and find most appropriate experts with beam search
         grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
-        batch_experts = self.beam_search(grid_scores, self.k_best)
+        chosen_experts = self.beam_search(grid_scores, self.k_best)
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
-        # 2.1 call chosen experts (run them in background to save time)
-        batch_outputs_async = [
-            self.thread_pool.apply_async(self._run_experts,
-                                         args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
-                                         kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
-            for i, chosen_experts in enumerate(batch_experts)
-        ]
-
-        # 2.2 compute *differentiable* logits for each expert
-        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
-        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert
-
-        batch_outputs = []
-        for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits):
-            expert_outputs: Dict[RemoteExpert, Any] = output_async.get()
-            flat_experts, flat_outputs = zip(*expert_outputs.items())
-
-            # 3.1. normalize logits over only those experts that DID return output
-            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
-            flat_weights = torch.softmax(flat_logits, dim=-1)
+        # 2. compute *differentiable* logits for each chosen expert
+        expert_logits = self._score_experts(grid_scores, chosen_experts)
 
-        # 3.2. average each output across experts
-        average_outputs = nested_map(
-            lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
+        # 3. slice batched inputs into individual examples and assemble args for each expert call
+        expert_inputs = ((input, *args), kwargs)
+        input_schema = nested_map(lambda x: None, expert_inputs)
+        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
 
-            batch_outputs.append(average_outputs)
+        batch_jobs_args = tuple(
+            (expert_logits[i], chosen_experts[i], self.k_min, self.timeout_after_k_min,
+             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
+            for i in range(len(input))
+        )
 
-        # 4. concatenate mixture outputs from individual experts
-        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)
+        # 4. call experts in parallel and average their outputs, weighted by softmax over the logits
+        averaged_outputs_flat = map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)
+        return nested_pack(averaged_outputs_flat, self.outputs_schema)
 
     def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
         """
@@ -202,9 +201,7 @@ class _RemoteMoECall(torch.autograd.Function):
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
-        ctx._alive_contexts = alive_contexts
-        ctx._backward_k_min = backward_k_min
-        ctx._backward_timeout = backward_timeout
+        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
         return tuple(map(torch.Tensor.detach, flat_average_outputs))
 
     @classmethod
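
The single-attribute trick above matters because save_for_backward only accepts tensors; other Python objects, like the emulated autograd contexts, must be stashed on ctx directly. A self-contained sketch of the pattern, assuming nothing beyond stock PyTorch:

# sketch_saved_non_tensors.py -- standalone illustration of the ctx pattern
import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale: float):
        ctx.save_for_backward(x)            # tensors: tracked by autograd machinery
        ctx._saved_non_tensors = (scale,)   # arbitrary Python objects: plain attribute
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        (scale,) = ctx._saved_non_tensors
        return grad_output * scale, None    # None: no gradient for the float argument

y = ScaleBy.apply(torch.ones(3, requires_grad=True), 2.0)
y.sum().backward()  # each input element receives gradient 2.0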
@@ -212,11 +209,11 @@
     def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         """ Like normal backward, but we ignore any experts that failed during backward pass """
         expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
-        alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
+        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors
 
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
-        results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
+        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=None, timeout_total=backward_timeout)
         backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
         backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
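
For context on the run_and_await_k call above: its role is to run all backward jobs in parallel and return once at least k of them succeed, within a total time budget, reporting failed or late jobs as None so the caller can skip them. A simplified self-contained sketch of those semantics (an illustration, not the repo's implementation; the timeout_after_k grace period is omitted):

# sketch_run_and_await_k.py -- simplified semantics, not the library implementation
import concurrent.futures as cf
import time
from typing import Any, Callable, List, Optional

def run_and_await_k_sketch(jobs: List[Callable[[], Any]], k: int,
                           timeout_total: Optional[float]) -> List[Optional[Any]]:
    results: List[Optional[Any]] = [None] * len(jobs)
    num_succeeded = 0
    pool = cf.ThreadPoolExecutor(max_workers=max(1, len(jobs)))
    future_to_ix = {pool.submit(job): ix for ix, job in enumerate(jobs)}
    try:
        for future in cf.as_completed(future_to_ix, timeout=timeout_total):
            try:
                results[future_to_ix[future]] = future.result()
                num_succeeded += 1
            except Exception:
                continue  # a failed job simply stays None
            if num_succeeded >= k:
                break  # enough results collected; abandon the stragglers
    except cf.TimeoutError:
        pass  # total budget exhausted; return whatever arrived in time
    finally:
        pool.shutdown(wait=False)  # do not block on unfinished jobs
    return results

jobs = [lambda: 1, lambda: (time.sleep(5), 2)[1], lambda: 3]
print(run_and_await_k_sketch(jobs, k=2, timeout_total=1.0))  # -> [1, None, 3]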
|