Prechádzať zdrojové kódy

wip: implement grad wrt logits

justheuristic pred 5 rokmi
rodič
commit
b79d05e037
1 zmenený súbor s 1 pridaním a 1 odobraním
  1. 1 1
      tesseract/client/moe.py

+ 1 - 1
tesseract/client/moe.py

@@ -224,7 +224,7 @@ class _RemoteMoECall(torch.autograd.Function):
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
         survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
         weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-
+        print('>>>', survived_grad_inputs)
         flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
                                  for stacked_grad_inp in map(torch.stack, survived_grad_inputs))