Selaa lähdekoodia

wip: implement grad wrt logits

justheuristic 5 vuotta sitten
vanhempi
commit
3b8a104bfa
1 muutettua tiedostoa jossa 2 lisäystä ja 3 poistoa
  1. 2 3
      tesseract/client/moe.py

+ 2 - 3
tesseract/client/moe.py

@@ -198,10 +198,9 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
-        stacked_alive_outputs = tuple(map(torch.stack, list(zip(*alive_outputs))))
+        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
-        print('!' * 50, [x.shape for x in flat_average_outputs])
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
@@ -232,7 +231,7 @@ class _RemoteMoECall(torch.autograd.Function):
         grad_wrt_probs = sum(tuple(
             torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
                       dim=tuple(range(1, stacked_avive_out.ndim)))
-            for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
+            for grad_out, stacked_avive_out in zip(grad_outputs_flat, zip(*stacked_alive_outputs))
         ))
         softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
         grad_wrt_logits = grad_wrt_probs @ softmax_jacobian