Browse Source

wip: parallel fault-tolerant moe backward pass

justheuristic 5 years ago
parent
commit
6fb99c8746

+ 2 - 0
tesseract/client/expert.py

@@ -2,6 +2,7 @@ from typing import Tuple, Optional
 
 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
 from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection
 
@@ -69,6 +70,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         connection = Connection.create(ctx.host, ctx.port)
         payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))

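A minimal sketch (not from this commit) of the once_differentiable pattern used above, with a toy Square function: once_differentiable marks a backward that computes its gradients outside of autograd, so attempting a double backward through it raises an error instead of silently producing wrong second-order gradients.

    import torch
    from torch.autograd.function import once_differentiable

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x ** 2

        @staticmethod
        @once_differentiable  # gradients below are computed with autograd disabled
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            return 2 * x * grad_output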
+ 81 - 27
tesseract/client/moe.py

@@ -1,14 +1,19 @@
 import multiprocessing as mp
 import multiprocessing.pool
-from concurrent.futures import as_completed
-from typing import Tuple, List, Dict, Any
+from concurrent.futures import as_completed, TimeoutError, Future
+from functools import partial
+from itertools import chain
+from typing import Tuple, List, Dict, Any, Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
-from .expert import RemoteExpert
-from ..utils import nested_map, check_numpy, run_in_background
+from .expert import RemoteExpert, _RemoteModuleCall
+from ..utils import nested_map, check_numpy, run_in_background, run_and_await_k, nested_pack, BatchTensorProto, \
+    nested_flatten, DUMMY
+from ..utils.autograd import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward
 
 
 class RemoteMixtureOfExperts(nn.Module):
@@ -140,29 +145,7 @@ class RemoteMixtureOfExperts(nn.Module):
             uid for row in beam for uid in row if uid != self.expert_padding)))
         unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
 
-        return [
-            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
-            for row in beam]
-
-    def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
-        future_to_expert = {run_in_background(expert, *args, **kwargs): expert for expert in experts}
-        pending_futures = set(future_to_expert.keys())
-        outputs = {}  # {expert -> output}
-
-        # await first k futures for as long as it takes
-        for future in as_completed(list(pending_futures), timeout=None):
-            pending_futures.remove(future)
-            if not future.exception():
-                outputs[future_to_expert.pop(future)] = future.result()
-                if len(outputs) > self.k_min:
-                    break
-
-        # await stragglers for at most self.timeout_after_k_min
-        for future in as_completed(pending_futures, timeout=self.timeout_after_k_min):
-            if not future.exception():
-                outputs[future_to_expert.pop(future)] = future.result()
-
-        return outputs
+        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
 
     def _score_experts(self, grid_scores: List[torch.Tensor],
                        experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
@@ -186,3 +169,74 @@ class RemoteMixtureOfExperts(nn.Module):
             output_dicts[batch_i][expert] = score
 
         return output_dicts
+
+
+class _RemoteMoECall(torch.autograd.Function):
+    """
+    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
+    This function can recover from individual failures during the forward and/or backward pass.
+    For a user-friendly version of this function, use the RemoteMixtureOfExperts module.
+    """
+    MIN_TOTAL_WEIGHT = 1e-3
+
+    @classmethod
+    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
+                *flat_inputs: torch.Tensor, input_schema, k_min: int, timeout_after_k_min: float, backward_k_min: int,
+                timeout_total: Optional[float], backward_timeout: Optional[float]) -> Tuple[torch.Tensor]:
+        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
+        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
+
+        # 1. call experts and await results
+        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
+        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)
+
+        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
+                                                        if not isinstance(result, BaseException)])
+        #     ^               ^            ^-- a list of indices of experts that returned outputs in time
+        #      \               \-- a list of flattened outputs of every expert that didn't die on us
+        #       \-- a list of autograd contexts, used for the parallel backward pass
+        alive_ix = torch.as_tensor(alive_ix)  # as a tensor so it can index expert_logits and be saved for backward
+
+        # 2. compute softmax weights for alive experts and average outputs
+        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
+
+        flat_average_outputs = tuple(map(
+            lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
+
+        # 3. save tensors and contexts needed for the backward pass (save_for_backward accepts tensors only)
+        ctx.save_for_backward(*flat_inputs, expert_logits, alive_ix, alive_expert_probs)
+        ctx._alive_contexts = alive_contexts
+        ctx._backward_k_min = backward_k_min
+        ctx._backward_timeout = backward_timeout
+        return tuple(map(torch.Tensor.detach, flat_average_outputs))
+
+    @classmethod
+    @once_differentiable
+    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+        """ Like normal backward, but we ignore any experts that failed during backward pass """
+        *flat_inputs, expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
+        alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
+
+        jobs = [partial(cls._run_expert_backward, expert_ctx, prob, *grad_outputs_flat)
+                for expert_ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
+        results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
+        survived_backward, survived_grad_inputs = zip(*[(i, grads) for i, grads in enumerate(results)
+                                                        if not isinstance(grads, BaseException)])
+
+        survived_ix = alive_ix[list(survived_backward)]
+        survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
+
+        flat_grad_inputs = tuple(map(
+            lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),
+            *survived_grad_inputs))
+
+        grad_logits = None  # TODO
+        return (grad_logits, None, *flat_grad_inputs, None, None, None, None, None, None)
+
+    @staticmethod
+    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
+        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
+        flat_inputs = nested_flatten((args, kwargs))
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, *flat_inputs)
+
+    @staticmethod
+    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
+        return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))

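To illustrate the averaging performed in the forward pass above (a standalone sketch, not from the commit; the tensor shapes are made up for the example): softmax weights are computed only over the experts that responded in time, and their outputs are averaged with those weights.

    import torch

    expert_logits = torch.tensor([1.0, 2.0, 0.5, 3.0])      # gating logits for 4 experts
    alive_ix = torch.as_tensor([0, 1, 3])                    # experts 0, 1 and 3 returned outputs in time
    alive_outputs = [torch.randn(2, 8) for _ in alive_ix]    # one flattened output tensor per alive expert

    alive_probs = torch.softmax(expert_logits[alive_ix], dim=0)   # weights renormalized over survivors
    averaged = sum(out * w for out, w in zip(alive_outputs, alive_probs))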
+ 36 - 0
tesseract/utils/autograd.py

@@ -0,0 +1,36 @@
+from typing import Tuple, Any
+
+import torch
+from torch.autograd.function import _ContextMethodMixin
+
+
+class EmulatedAutogradContext(_ContextMethodMixin):
+    """
+    A special class that pretends to be a pytorch autograd context. Used to circumvent limitations of pytorch autograd,
+    such as running several parallel backwards or transferring backward to a separate device.
+    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
+    """
+    @property
+    def saved_tensors(self):
+        return tuple(self.to_save)
+
+
+def run_isolated_forward(func: torch.autograd.Function, *args, **kwargs) -> Tuple[EmulatedAutogradContext, Any]:
+    """
+    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
+    can be used to run backward through the same graph (manually by the user).
+    """
+    ctx = EmulatedAutogradContext()
+    # create detached copies of every tensor input so that we can differentiate w.r.t. them without modifying
+    # the original variables; non-tensor arguments (e.g. expert uid, host and port) are passed through as-is
+    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
+    kwargs = {k: x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x
+              for k, x in kwargs.items()}
+    with torch.no_grad():
+        return ctx, func.forward(ctx, *args, **kwargs)
+
+
+def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
+    """
+    run backward pass for :func: in an isolated graph that was previously created through run_isolated_forward
+    """
+    with torch.no_grad():
+        return func.backward(ctx, *grad_outputs)

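A usage sketch for these helpers (the Scale function below is a toy defined only for illustration): run_isolated_forward runs the function in a detached graph and returns an emulated context, which can later be replayed through run_isolated_backward, e.g. from a worker thread.

    import torch
    from tesseract.utils.autograd import run_isolated_forward, run_isolated_backward

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, weight):
            ctx.save_for_backward(weight)
            return x * weight

        @staticmethod
        def backward(ctx, grad_output):
            weight, = ctx.saved_tensors
            return grad_output * weight, None

    x, w = torch.randn(3, requires_grad=True), torch.tensor(2.0)
    ctx, out = run_isolated_forward(Scale, x, w)                    # no graph is recorded on x
    grad_x, _ = run_isolated_backward(Scale, ctx, torch.ones_like(out))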
+ 50 - 1
tesseract/utils/threading.py

@@ -1,5 +1,7 @@
-from concurrent.futures import Future
+from concurrent.futures import Future, as_completed, TimeoutError
+import time
 from threading import Thread
+from typing import Optional, List
 
 
 def run_in_background(func: callable, *args, **kwargs) -> Future:
@@ -22,3 +24,50 @@ def run_forever(func: callable, *args, **kwargs):
         while True:
             func(*args, **kwargs)
     return run_in_background(repeat)
+
+
+def run_and_await_k(jobs: List[callable], k: int,
+                    timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
+    """
+    Runs all :jobs: asynchronously and waits for at least k of them to finish
+    :param jobs: functions to call asynchronously
+    :param k: how many functions should finish for call to be successful
+    :param timeout_after_k: after reaching k finished jobs, wait for this long before cancelling
+    :param timeout_total: if specified, cancel any jobs that are still running after this many seconds in total
+    :returns: a list of either results or exceptions for each job
+    """
+    jobs = list(jobs)
+    assert k <= len(jobs), f"Can't await {k} out of {len(jobs)} jobs."
+    start_time = time.time()
+    future_to_ix = {run_in_background(job): i for i, job in enumerate(jobs)}
+    outputs = [None] * len(jobs)
+    success_count = 0
+
+    try:
+        # await first k futures for as long as it takes
+        for future in as_completed(list(future_to_ix.keys()), timeout=timeout_total):
+            success_count += int(not future.exception())
+            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
+            if success_count >= k:
+                break  # we have enough futures to succeed
+            if success_count + len(future_to_ix) < k:
+                failed = len(jobs) - success_count - len(future_to_ix)
+                raise ValueError(f"Couldn't get enough results: too many jobs failed ({failed} / {len(jobs)})")
+
+        # await stragglers for at most timeout_after_k or whatever total time is left
+        if timeout_total is not None:
+            time_left = timeout_total - (time.time() - start_time)
+            if timeout_after_k is not None:
+                time_left = min(timeout_after_k, time_left)
+        else:
+            time_left = timeout_after_k
+        for future in as_completed(list(future_to_ix.keys()), timeout=time_left):
+            success_count += int(not future.exception())
+            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
+
+    except TimeoutError:
+        if success_count < k:
+            raise TimeoutError(f"Couldn't get enough results: time limit exceeded (got {success_count} of {k})")
+    finally:
+        for future, index in future_to_ix.items():
+            future.cancel()
+            if future.done() and not future.cancelled():
+                outputs[index] = future.result() if not future.exception() else future.exception()
+            else:
+                outputs[index] = TimeoutError(f"Job {index} did not finish in time")
+    return outputs
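A usage sketch for run_and_await_k (not from the commit; assumes the helper is re-exported from tesseract.utils as in moe.py above):

    import time
    from functools import partial
    from tesseract.utils import run_and_await_k

    def flaky_job(i):
        if i == 2:
            raise RuntimeError("simulated failure")   # one job fails outright
        time.sleep(0.1 * i)                           # the others finish at different speeds
        return i * i

    jobs = [partial(flaky_job, i) for i in range(4)]
    results = run_and_await_k(jobs, k=2, timeout_after_k=0.05, timeout_total=5.0)
    # results is aligned with jobs: a value for each job that finished, an exception otherwise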