allow batch-parallel backprop

justheuristic, 5 years ago
Parent commit: cb77e99643
3 changed files with 103 additions and 43 deletions
  1. tesseract/client/moe.py (+38 -41)
  2. tesseract/utils/__init__.py (+1 -0)
  3. tesseract/utils/autograd.py (+64 -2)

+ 38 - 41
tesseract/client/moe.py

@@ -1,7 +1,7 @@
 import multiprocessing as mp
 import multiprocessing.pool
 from functools import partial
-from typing import Tuple, List, Dict, Any, Optional
+from typing import Tuple, List, Dict, Optional
 
 import numpy as np
 import torch
@@ -10,7 +10,7 @@ from torch.autograd.function import once_differentiable
 
 from .expert import RemoteExpert, _RemoteModuleCall
 from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY
-from ..utils.autograd import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward
+from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
 class RemoteMixtureOfExperts(nn.Module):
@@ -35,58 +35,57 @@ class RemoteMixtureOfExperts(nn.Module):
     :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
     """
     def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
-                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
+                 k_best, k_min=1, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
+                 uid_prefix='', expert_padding=None):
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
-        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
+        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
+        self.timeout_after_k_min, self.backward_timeout = timeout_after_k_min, backward_timeout
 
         self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
 
-    def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
+        # grab some expert to infer the ensemble output schema
+        dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)
+        self.output_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['output_schema']
+
+    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
         Choose k best experts with beam search, then call chosen experts and average their outputs.
-
-        :param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
-        :returns: averaged predictions of all experts that delivered on time
+        :param input: a tensor of values used to estimate the gating function, batch-first
+        :param args: extra positional parameters that will be passed to each expert after input, batch-first
+        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
+        :returns: averaged predictions of all experts that delivered results on time, as a nested structure of batch-first tensors
         """
-        assert len(input.shape) == 2
+        if self.allow_broadcasting and len(input.shape) != 2:
+            # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
+            flattened_dims = input.shape[:-1]
+            input_flat = input.view(-1, input.shape[-1])
+            args_flat = [tensor.view(-1, *tensor.shape[len(flattened_dims):]) for tensor in args]
+            kwargs_flat = {key: tensor.view(-1, *tensor.shape[len(flattened_dims):]) for key, tensor in kwargs.items()}
+            out_flat = self.forward(input_flat, *args_flat, **kwargs_flat)
+            return nested_map(lambda tensor: tensor.view(*flattened_dims, *tensor.shape[1:]), out_flat)
 
         # 1. compute scores and find most appropriate experts with beam search
         grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
-        batch_experts = self.beam_search(grid_scores, self.k_best)
+        chosen_experts = self.beam_search(grid_scores, self.k_best)
         # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
-        # 2.1 call chosen experts (run them in background to save time)
-        batch_outputs_async = [
-            self.thread_pool.apply_async(self._run_experts,
-                                         args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
-                                         kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
-            for i, chosen_experts in enumerate(batch_experts)
-        ]
-
-        # 2.2 compute *differentiable* logits for each expert
-        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
-        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert
-
-        batch_outputs = []
-        for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits):
-            expert_outputs: Dict[RemoteExpert, Any] = output_async.get()
-            flat_experts, flat_outputs = zip(*expert_outputs.items())
-
-            # 3.1. normalize logits over only those experts that DID return output
-            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
-            flat_weights = torch.softmax(flat_logits, dim=-1)
+        expert_logits = self._score_experts(grid_scores, chosen_experts)
 
-            # 3.2. average each output across experts
-            average_outputs = nested_map(
-                lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
+        expert_inputs = ((input, *args), kwargs)
+        input_schema = nested_map(lambda x: None, expert_inputs)
+        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
 
-            batch_outputs.append(average_outputs)
+        batch_jobs_args = tuple(
+            (expert_logits[i], chosen_experts[i], self.k_min, self.timeout_after_k_min,
+             self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
+            for i in range(len(input))
+        )
 
-        # 4. concatenate mixture outputs from individual experts
-        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)
+        averaged_outputs_flat = map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)
+        return nested_pack(averaged_outputs_flat, self.output_schema)
 
     def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
         """
@@ -202,9 +201,7 @@ class _RemoteMoECall(torch.autograd.Function):
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
-        ctx._alive_contexts = alive_contexts
-        ctx._backward_k_min = backward_k_min
-        ctx._backward_timeout = backward_timeout
+        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
         return tuple(map(torch.Tensor.detach, flat_average_outputs))
 
     @classmethod
@@ -212,11 +209,11 @@ class _RemoteMoECall(torch.autograd.Function):
     def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         """ Like normal backward, but we ignore any experts that failed during backward pass """
         expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
-        alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
+        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors
 
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
-        results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
+        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=None, timeout_total=backward_timeout)
         backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
         backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
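
A quick aside on the per-sample dispatch pattern used by the new forward: every batch-first tensor from nested_flatten(expert_inputs) is split along dim 0 and the slices are regrouped into one argument tuple per sample, each of which becomes one parallel job for map_with_parallel_backward. The sketch below is not part of this commit; the tensor names and shapes are made up for illustration.

    import torch

    # two batch-first tensors standing in for nested_flatten(expert_inputs); shapes are arbitrary
    flat_tensors = [torch.randn(4, 8), torch.ones(4, 3)]

    # split each tensor into size-1 slices along the batch axis, then regroup the slices per sample
    flat_inputs_per_sample = tuple(zip(*[tensor.split(1, dim=0) for tensor in flat_tensors]))

    assert len(flat_inputs_per_sample) == 4              # one tuple of tensors per sample
    assert flat_inputs_per_sample[0][0].shape == (1, 8)  # each slice keeps its batch dimension
    assert flat_inputs_per_sample[0][1].shape == (1, 3)

Each per-sample tuple is then packed into batch_jobs_args together with that sample's expert logits, chosen experts, and timeouts, so _RemoteMoECall runs once per input row.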

+ 1 - 0
tesseract/utils/__init__.py

@@ -5,3 +5,4 @@ from .proto import *
 from .serializer import *
 from .shared_future import *
 from .threading import *
+from .autograd import *

+ 64 - 2
tesseract/utils/autograd.py

@@ -1,10 +1,19 @@
+"""
+Temporary autograd extensions to enable inter-op parallelism during backward pass
+Note: we should get rid of this module if https://github.com/pytorch/pytorch/pull/33157 reaches a pytorch release
+"""
+from itertools import chain
 from typing import Tuple, Any
+from concurrent.futures import Future
 
+import numpy as np
 import torch
-from torch.autograd.function import _ContextMethodMixin
+import torch.autograd.function
 
+from .threading import run_in_background
 
-class EmulatedAutogradContext(_ContextMethodMixin):
+
+class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
     """
     A special class that pretends to be a pytorch autograd context. Used to circumvent limitations of pytorch autograd,
     such as running several parallel backwards or transferring backward to a separate device.
@@ -33,3 +42,56 @@ def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradCo
     """
     with torch.no_grad():
         return func.backward(ctx, *grad_outputs)
+
+
+def map_with_parallel_backward(
+        func: torch.autograd.Function, *args_per_call: Tuple[torch.Tensor, ...]) -> Tuple[Tuple[torch.Tensor, ...], ...]:
+    """
+    Apply an autograd function to several sets of inputs with two extra guarantees:
+    (1) both the forward and the backward pass happen concurrently for each set of inputs
+    (2) any operation dependent on any individual function will wait for all functions to finish
+    :param func: torch autograd function to be called several times in parallel
+    :param args_per_call: a sequence of tuples of arguments, each tuple corresponds to one function call
+    :returns: a tuple of outputs from each func call
+
+    Note: this function currently requires that all :func: calls succeed (i.e. do not raise an exception).
+    """
+    arg_counts = list(map(len, args_per_call))
+    assert len(set(arg_counts)) == 1, "All input sets must have the same number of arguments"
+    output_strides_ph = Future()
+    flat_outputs: Tuple[torch.Tensor, ...] = _ParallelApplyFunction.apply(
+        func, len(args_per_call), arg_counts[0], output_strides_ph, *chain(*args_per_call))
+    output_strides = output_strides_ph.result()
+    return tuple(flat_outputs[output_strides[i]: output_strides[i + 1]] for i in range(len(output_strides) - 1))
+
+
+class _ParallelApplyFunction(torch.autograd.Function):
+    """
+    A special torch autograd function that runs another function several times in parallel.
+    Please do not call this function directly. Use map_with_parallel_backward instead.
+    Unlike default pytorch behavior, the backward pass for each function will also happen in parallel.
+    """
+    @staticmethod
+    def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
+                output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:
+        assert num_calls * num_args_per_call == len(args_flat)
+        args_per_call = [args_flat[i * num_args_per_call: (i + 1) * num_args_per_call] for i in range(num_calls)]
+
+        futures = [run_in_background(run_isolated_forward, func, *args) for args in args_per_call]
+
+        outputs, contexts = zip(*[future.result() for future in futures])
+        output_strides = np.cumsum([0] + list(map(len, outputs)))
+        ctx._inner_func = func
+        ctx._call_contexts = contexts
+        ctx._output_strides = output_strides
+        output_strides_ph.set_result(output_strides)
+        return tuple(chain(*outputs))
+
+    @staticmethod
+    def backward(ctx, *grad_outputs_flat: torch.Tensor):
+        func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
+        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]] for i in range(len(contexts))]
+        futures = [run_in_background(run_isolated_backward, func, context, *grads)
+                   for context, grads in zip(contexts, grad_outputs_per_call)]
+        flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
+        return None, None, None, None, *flat_grads_wrt_input
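
To see how the new utility is meant to be used, here is a hedged sketch: a toy autograd function (_Square, invented for this example) is mapped over several independent inputs so that its forward passes run concurrently, and so will its backward passes once the surrounding graph calls backward. The sketch assumes the module is importable as tesseract.utils.autograd and that the wrapped function returns a tuple of tensors, the same contract _RemoteMoECall follows.

    import torch
    from tesseract.utils.autograd import map_with_parallel_backward

    class _Square(torch.autograd.Function):
        """ toy function used only for illustration: y = x ** 2 """
        @staticmethod
        def forward(ctx, x: torch.Tensor):
            ctx.save_for_backward(x)
            return (x ** 2,)  # a tuple of output tensors, mirroring _RemoteMoECall

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            (x,) = ctx.saved_tensors
            return (2 * x * grad_output,)  # a tuple of grads, one per forward argument

    inputs = [torch.randn(3, requires_grad=True) for _ in range(4)]
    outputs = map_with_parallel_backward(_Square, *[(x,) for x in inputs])  # one args-tuple per call

    # outputs[i] is the tuple returned by the i-th call; any op that touches it waits for all calls
    loss = sum(out[0].sum() for out in outputs)
    loss.backward()  # per-call backward passes are dispatched in parallel by _ParallelApplyFunction

Note that partial failures are not handled here: as the docstring above says, map_with_parallel_backward currently requires every call to succeed.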