
wip: implement grad wrt logits

justheuristic, 5 years ago
Commit 2146fb6d0e
1 file changed, 1 insertion(+), 1 deletion(-)

tesseract/client/moe.py  +1 −1

@@ -224,7 +224,6 @@ class _RemoteMoECall(torch.autograd.Function):
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
         survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
         weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-        print('>>>', survived_grad_inputs)
         flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
                                  for stacked_grad_inp in map(torch.stack, survived_grad_inputs))
 
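Note: the hunk above re-weights each surviving expert's gradient by the ratio between its softmax probability among the backward survivors and its original probability among all alive experts. Below is a minimal, self-contained sketch of that computation with toy shapes; `dot_along_first_axis` is not shown in this diff, so its definition here is an assumption based on how it is called.

```python
import torch

def dot_along_first_axis(weights, stacked):
    # Assumed helper, matching its usage above: weighted sum over axis 0,
    # i.e. sum_i weights[i] * stacked[i], broadcasting weights over trailing dims.
    return torch.einsum('i,i...->...', weights, stacked)

# Toy setup: 4 experts answered the forward pass ("alive"),
# experts 0, 1 and 3 also answered the backward pass ("survivors").
expert_logits = torch.randn(4)
alive_ix = torch.arange(4)
alive_expert_probas = torch.softmax(expert_logits, dim=0)

backward_survivors_in_alive_ix = torch.tensor([0, 1, 3])
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]

# Renormalize over survivors only, then correct each survivor's gradient by
# the ratio of its renormalized probability to its original probability.
survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]

# One flat input whose per-expert gradients have shape (5,): stacking gives
# shape (3, 5); the weighted sum collapses the expert axis back to (5,).
survived_grad_inputs = [[torch.randn(5) for _ in range(3)]]
flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
                         for stacked_grad_inp in map(torch.stack, survived_grad_inputs))
```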
@@ -247,6 +246,7 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
+        print('!!!', [g.shape for g in grad_outputs])
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
         grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
         return grad_inputs
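
The added print is WIP debugging toward the commit's stated goal, gradients w.r.t. the gating logits. For context, a minimal sketch of the scaling that `_run_expert_backward` applies before calling `run_isolated_backward`: multiplying each upstream gradient by the expert's mixture weight is equivalent to backpropagating through `weight * expert_output`. All names and shapes below are illustrative, not taken from the codebase.

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
out = x * 2.0                     # stand-in for one expert's forward pass
grad_out = torch.randn_like(out)  # upstream gradient for this expert's output
weight = torch.tensor(0.7)        # this expert's mixture weight

# Scale the upstream gradient by the mixture weight, as in
# `grad * weight for grad in grad_outputs` above.
(grad_x,) = torch.autograd.grad(out, x, grad_outputs=grad_out * weight)

# Same result as differentiating weight * out directly: d(out)/dx = 2 here.
assert torch.allclose(grad_x, 2.0 * weight * grad_out)
```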