
wip: implement grad wrt logits

justheuristic, 5 years ago
Commit 5cc3cd99c3
1 changed file with 1 addition and 2 deletions
tesseract/client/moe.py (+1 -2)

@@ -215,7 +215,7 @@ class _RemoteMoECall(torch.autograd.Function):
         """ Like normal backward, but we ignore any experts that failed during backward pass """
         expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
         alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
-
+        print(grad_outputs_flat)
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
@@ -246,7 +246,6 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        print('!!!', [g.shape for g in grad_outputs])
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
         grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
         return grad_inputs
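
For context, the hunk above feeds a k-of-n backward pattern: one backward job per alive expert is built with functools.partial, run_and_await_k keeps the first k_min results that finish, and _run_expert_backward scales each upstream gradient by that expert's gating probability before running the isolated backward. Below is a minimal, self-contained sketch of that pattern using plain PyTorch and concurrent.futures; scaled_backward and await_first_k are hypothetical stand-ins for illustration, not tesseract's actual run_isolated_backward or run_and_await_k.

    # Hypothetical sketch of the "first k expert backwards win" pattern.
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from functools import partial

    import torch


    def scaled_backward(weight: torch.Tensor, *grad_outputs: torch.Tensor):
        """Toy analogue of _run_expert_backward: scale the upstream gradients
        by this expert's gating probability (the real code then runs an
        isolated backward through the remote expert call)."""
        return tuple(grad * weight for grad in grad_outputs)


    def await_first_k(jobs, k: int, timeout: float = 1.0):
        """Run all jobs concurrently and return the first k results that
        complete; stragglers and failures past that point are ignored."""
        results = []
        with ThreadPoolExecutor(max_workers=len(jobs)) as pool:
            futures = [pool.submit(job) for job in jobs]
            for future in as_completed(futures, timeout=timeout):
                results.append(future.result())
                if len(results) >= k:
                    break
        return results


    if __name__ == "__main__":
        grad_outputs_flat = (torch.randn(4, 16),)            # upstream grads w.r.t. expert outputs
        alive_expert_probas = torch.tensor([0.5, 0.3, 0.2])  # gating probabilities of alive experts
        jobs = [partial(scaled_backward, prob, *grad_outputs_flat)
                for prob in alive_expert_probas.split(1)]
        results = await_first_k(jobs, k=2)
        print(len(results), "of", len(jobs), "expert backward jobs kept")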