Просмотр исходного кода

wip: implement grad wrt logits

justheuristic 5 лет назад
Родитель
Коммит
69cd699de0
1 изменённый файл: 0 добавлений и 2 удаления
  1. 0 2
      tesseract/client/moe.py

+ 0 - 2
tesseract/client/moe.py

@@ -228,13 +228,11 @@ class _RemoteMoECall(torch.autograd.Function):
                                  for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
 
         # compute grad w.r.t. logits
-        print('A')
         grad_wrt_probs = sum(tuple(
             torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
                       dim=tuple(range(1, stacked_avive_out.ndim)))
             for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
         ))
-        print('B')
         softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
         grad_wrt_logits = grad_wrt_probs @ softmax_jacobian