@@ -1,14 +1,19 @@
 import multiprocessing as mp
 import multiprocessing.pool
-from concurrent.futures import as_completed
-from typing import Tuple, List, Dict, Any
+from concurrent.futures import as_completed, TimeoutError, Future
+from functools import partial
+from itertools import chain
+from typing import Tuple, List, Dict, Any, Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
-from .expert import RemoteExpert
-from ..utils import nested_map, check_numpy, run_in_background
+from .expert import RemoteExpert, _RemoteModuleCall
+from ..utils import nested_map, check_numpy, run_in_background, run_and_await_k, nested_pack, BatchTensorProto, \
+    nested_flatten, DUMMY
+from ..utils.autograd import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward
 
 
 class RemoteMixtureOfExperts(nn.Module):
@@ -140,29 +145,7 @@ class RemoteMixtureOfExperts(nn.Module):
             uid for row in beam for uid in row if uid != self.expert_padding)))
         unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
 
-        return [
-            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
-            for row in beam]
-
-    def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
-        future_to_expert = {run_in_background(expert, *args, **kwargs): expert for expert in experts}
-        pending_futures = set(future_to_expert.keys())
-        outputs = {}  # {expert -> output}
-
-        # await first k futures for as long as it takes
-        for future in as_completed(list(pending_futures), timeout=None):
-            pending_futures.remove(future)
-            if not future.exception():
-                outputs[future_to_expert.pop(future)] = future.result()
-            if len(outputs) > self.k_min:
-                break
-
-        # await stragglers for at most self.timeout_after_k_min
-        for future in as_completed(pending_futures, timeout=self.timeout_after_k_min):
-            if not future.exception():
-                outputs[future_to_expert.pop(future)] = future.result()
-
-        return outputs
+        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
 
     def _score_experts(self, grid_scores: List[torch.Tensor],
                        experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
@@ -186,3 +169,74 @@ class RemoteMixtureOfExperts(nn.Module):
             output_dicts[batch_i][expert] = score
 
         return output_dicts
+
+
+class _RemoteMoECall(torch.autograd.Function):
+    """
+    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
+    This function can recover from individual failures during the forward and/or backward pass.
+    For a user-friendly version of this function, use the RemoteMixtureOfExperts module.
+    """
+    MIN_TOTAL_WEIGHT = 1e-3
+
+    @classmethod
+    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
+                *flat_inputs: torch.Tensor, input_schema, k_min: int, timeout_after_k_min: float, backward_k_min: int,
+                timeout_total: Optional[float], backward_timeout: Optional[float]) -> Tuple[torch.Tensor, ...]:
+        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
+        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
+
+        # 1. call experts and await results
+        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
+        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)
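+        # run_and_await_k appears to return one entry per job, in submission order: either the job's
+        # return value or the exception it raised (including timeouts), hence the filtering below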
+
+        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
+                                                        if not isinstance(result, BaseException)])
+        #      ^              ^              ^-- a list of indices of experts that returned outputs in time
+        #       \              \-- list of outputs of every expert that didn't die on us
+        #        \-- a list of autograd contexts, used for parallel backward
+
+        # 2. compute softmax weights for alive experts and average outputs
+        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)  # tensor indices, usable for indexing/saving
+        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
+
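+        # each element of alive_outputs is one expert's tuple of output tensors, so the map(...) below
+        # walks over matching positions across experts and the lambda takes their weighted sum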
+        flat_average_outputs = tuple(map(
+            lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
+
+        # 3. save expert contexts and averaging weights for the backward pass
+        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *flat_inputs)
+        ctx._alive_contexts = alive_contexts
+        ctx._backward_k_min = backward_k_min
+        ctx._backward_timeout = backward_timeout
+        return tuple(map(torch.Tensor.detach, flat_average_outputs))
+
+    @classmethod
+    @once_differentiable
+    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+        """ Like the normal backward, but we ignore any experts that failed during the backward pass """
+        expert_logits, alive_ix, alive_expert_probas, *flat_inputs = ctx.saved_tensors
+        alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
+
+        jobs = [partial(cls._run_expert_backward, alive_ctx, prob, *grad_outputs_flat)
+                for alive_ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
+        results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
+        survived_backward, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)
+                                                        if not isinstance(grads, BaseException)))
+
+        survived_ix = alive_ix[list(survived_backward)]
+        survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
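+        # gradient weights are re-normalized over the experts that also survived the backward pass;
+        # experts that failed here simply do not contribute to the input gradients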
+
+        flat_grad_inputs = tuple(map(
+            lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),
+            *survived_grad_inputs))
+
+        grad_logits = None  # TODO
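+        # the returned tuple must line up with forward's inputs: expert_logits, experts, *flat_inputs,
+        # followed by one None for each of the six non-differentiable configuration arguments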
+        return (grad_logits, None, *flat_grad_inputs, None, None, None, None, None, None)
+
+    @staticmethod
+    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
+        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
+        flat_inputs = nested_flatten((args, kwargs))
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, *flat_inputs)
+
+    @staticmethod
+    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
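+        """ Replay backward for one remote call, scaling the incoming gradients by this expert's weight """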
+        return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))