
wip: implement grad wrt logits

justheuristic 5 years ago
Parent
Current commit
70149eafc7
1 changed file with 4 additions and 0 deletions

+ 4 - 0
tesseract/client/moe.py

@@ -219,14 +219,18 @@ class _RemoteMoECall(torch.autograd.Function):
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
+        print('A')
         backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
         backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
+        print('B')
         survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
         weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
         flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
                                  for stacked_grad_inp in map(torch.stack, survived_grad_inputs))
 
+        print('C')
+
         # compute grad w.r.t. logits
         grad_wrt_probs = sum(tuple(
             torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
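
The hunk is cut off mid-expression, but the visible grad_wrt_probs lines start the gradient with respect to the gating logits: the upstream gradient is dotted with each surviving expert's stacked output, and that result would then be pushed through the softmax Jacobian. Below is a minimal standalone PyTorch sketch of that computation, checked against autograd; all names are illustrative and none come from tesseract/client/moe.py:

import torch

# Toy setup: 3 surviving experts, each producing a 4-dim output.
logits = torch.randn(3, requires_grad=True)   # gating logits of the survivors
expert_outputs = torch.randn(3, 4)            # stacked expert outputs
grad_output = torch.randn(4)                  # upstream gradient w.r.t. the mixture output

# Forward: the mixture output is a softmax-weighted sum of expert outputs.
probs = torch.softmax(logits, dim=0)
mixture = (probs[:, None] * expert_outputs).sum(dim=0)

# Reference gradient via autograd.
mixture.backward(grad_output)

# The same gradient by hand: first grad w.r.t. the probabilities (dot with each
# expert's output, analogous to the grad_wrt_probs sum in the hunk), then through
# the softmax Jacobian: p_i * (g_i - sum_j p_j * g_j).
grad_wrt_probs = expert_outputs @ grad_output
manual = probs * (grad_wrt_probs - (probs * grad_wrt_probs).sum())
assert torch.allclose(logits.grad, manual, atol=1e-6)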