Prechádzať zdrojové kódy

deprecate custom dot

justheuristic 5 rokov pred
rodič
commit
69f5a33b5a
1 zmenil súbory, kde vykonal 2 pridanie a 6 odobranie
  1. 2 6
      tesseract/client/moe.py

+ 2 - 6
tesseract/client/moe.py

@@ -197,7 +197,7 @@ class _RemoteMoECall(torch.autograd.Function):
 
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
 
-        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
+        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
                                      for stacked_out in stacked_alive_outputs)
 
         # 3. save individual outputs for backward pass
@@ -222,7 +222,7 @@ class _RemoteMoECall(torch.autograd.Function):
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
         survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
         weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-        flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
+        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
                                  for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
 
         # compute grad w.r.t. logits
@@ -247,7 +247,3 @@ class _RemoteMoECall(torch.autograd.Function):
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
         grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
         return grad_inputs
-
-
-def dot_along_first_axis(x, y):
-    return (x.view(-1, *[1] * (y.ndim - 1)) * y).sum(0)