@@ -198,11 +198,12 @@ class _RemoteMoECall(torch.autograd.Function):
alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)

-        flat_average_outputs = tuple(map(
-            lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
+        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
+        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
+                                     for stacked_out in stacked_alive_outputs)

        # 3. save individual outputs for backward pass
- ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs)
+ ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
ctx._alive_contexts = alive_contexts
ctx._backward_k_min = backward_k_min
ctx._backward_timeout = backward_timeout
@@ -212,24 +213,31 @@ class _RemoteMoECall(torch.autograd.Function):
@once_differentiable
def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
""" Like normal backward, but we ignore any experts that failed during backward pass """
- expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
+ expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
        alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout

        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
- survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
- survived_backward = torch.as_tensor(survived_backward, device=expert_logits.device)
- survived_ix = alive_ix[survived_backward]
- survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
-
- flat_grad_inputs = tuple(map(
- lambda *tensors: sum(x * (weight / old_weight) for x, weight, old_weight
- in zip(tensors, survived_expert_probas, alive_expert_probas[survived_backward])),
- *survived_grad_inputs))
-
- grad_logits = None # TODO
- return grad_logits, None, None, None, None, None, None, None, *flat_grad_inputs
+ backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
+ backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
+ backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
+ survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
+ weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
+
+        flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
+                                 for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
+
+ # compute grad w.r.t. logits
+        grad_wrt_probs = sum(tuple(
+            torch.sum(grad_out[None, ...] * stacked_alive_out[backward_survivors_in_alive_ix],
+                      dim=tuple(range(1, stacked_alive_out.ndim)))
+            for grad_out, stacked_alive_out in zip(grad_outputs_flat, stacked_alive_outputs)
+        ))
+ softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
+        grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_probs @ softmax_jacobian)
+
+        return grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs

    @staticmethod
def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
@@ -242,3 +250,7 @@ class _RemoteMoECall(torch.autograd.Function):
backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
return grad_inputs
+
+
+def dot_along_first_axis(x, y):
+    return (x.view(-1, *[1] * (y.ndim - 1)) * y).sum(0)  # sum_i x[i] * y[i], broadcasting x over y's trailing dims
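For reference, the new grad-w.r.t.-logits path chains the gradient w.r.t. the mixture weights through the softmax Jacobian, dp_i/dz_j = p_i * (delta_ij - p_j), i.e. J = diag(p) - p p^T. The standalone sketch below (not part of the patch; the expert count and output shapes are made up for illustration) checks that this manual gradient, together with the dot_along_first_axis helper, matches autograd:

import torch

def dot_along_first_axis(x, y):
    return (x.view(-1, *[1] * (y.ndim - 1)) * y).sum(0)

# pretend 3 alive experts each returned one flat output of shape [4, 5]
expert_logits = torch.randn(3, requires_grad=True)
stacked_outputs = torch.randn(3, 4, 5)

probs = torch.softmax(expert_logits, dim=0)
average = dot_along_first_axis(probs, stacked_outputs)  # forward: weighted average of expert outputs
grad_out = torch.randn_like(average)                    # pretend upstream gradient

# manual gradient w.r.t. the logits, mirroring the backward above
grad_wrt_probs = torch.sum(grad_out[None, ...] * stacked_outputs, dim=(1, 2))
p = probs.detach()
softmax_jacobian = torch.diagflat(p) - torch.ger(p, p)
manual_grad_wrt_logits = grad_wrt_probs @ softmax_jacobian

# autograd reference
(average * grad_out).sum().backward()
assert torch.allclose(manual_grad_wrt_logits, expert_logits.grad, atol=1e-5)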
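The failure handling in backward can be read the same way: each expert's remote gradient already carries its forward weight, so the survivors are rescaled by the ratio between the softmax renormalized over survivors and the original forward probabilities, redistributing the mass of experts that failed. A toy illustration with made-up logits (variable names follow the patch):

import torch

expert_logits = torch.tensor([1.0, 2.0, 0.5, 1.5])
alive_ix = torch.tensor([0, 1, 2, 3])                  # all four experts answered during forward
alive_expert_probas = torch.softmax(expert_logits[alive_ix], dim=0)

backward_survivors_in_alive_ix = torch.tensor([1, 3])  # only experts 1 and 3 returned gradients in time
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)

weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
print(weight_ratios)  # both ratios exceed 1: the failed experts' probability mass goes to the survivors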