
wip: implement grad wrt logits

justheuristic 5 years ago
parent commit e8ee28a392
1 changed file with 3 additions and 3 deletions
      tesseract/client/moe.py

+ 3 - 3
tesseract/client/moe.py

@@ -199,8 +199,8 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
         stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
-        flat_average_outputs = sum(dot_along_first_axis(alive_expert_probs, stacked_out)
-                                   for stacked_out in stacked_alive_outputs)
+        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
+                                     for stacked_out in stacked_alive_outputs)
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
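
Note on the first hunk: each expert returns several flat output tensors, and `stacked_alive_outputs` holds one stack per output slot with shape `[num_alive, ...]`. The old `sum(...)` added the per-slot averages of *different* output slots together (collapsing the tuple into a single tensor, or failing on a shape mismatch); `tuple(...)` keeps one probability-weighted average per slot. A minimal sketch of the intended semantics, with a hypothetical stand-in for `dot_along_first_axis` (its real definition is not shown in this diff):

import torch

def dot_along_first_axis(weights, stacked):
    # hypothetical reimplementation of the repo's helper: weighted sum over
    # the expert axis. weights: [num_alive], stacked: [num_alive, ...]
    # -> probability-weighted average with the trailing shape of one output
    return (weights.view(-1, *[1] * (stacked.dim() - 1)) * stacked).sum(dim=0)

# two alive experts, each returning two flat output tensors of different sizes
alive_expert_probs = torch.softmax(torch.tensor([1.0, 2.0]), dim=0)
stacked_alive_outputs = (torch.randn(2, 4), torch.randn(2, 8))

# tuple(...) preserves the experts' output signature: one averaged tensor
# per output slot, instead of summing the slots into each other
flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                             for stacked_out in stacked_alive_outputs)
assert [t.shape for t in flat_average_outputs] == [torch.Size([4]), torch.Size([8])]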
@@ -215,7 +215,7 @@ class _RemoteMoECall(torch.autograd.Function):
         """ Like normal backward, but we ignore any experts that failed during backward pass """
         expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
         alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
-        print(grad_outputs_flat)
+
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
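
For context on the commit message ("implement grad wrt logits"): if the forward pass computes output[k] = Σ_i p_i · y_i[k] with p = softmax(expert_logits[alive_ix]), the gradient w.r.t. the gating logits can be assembled from the tensors saved above. The sketch below is an assumption about what that backward step would compute, not code from this commit; the function name and exact assembly are hypothetical:

import torch

def grad_wrt_logits(expert_logits, alive_ix, alive_expert_probas,
                    stacked_alive_outputs, grad_outputs_flat):
    # dL/dp_i = sum over output slots k of <grad_out[k], y_i[k]>
    grad_probs = sum((grad_out.unsqueeze(0) * stacked).flatten(1).sum(dim=1)
                     for grad_out, stacked in zip(grad_outputs_flat,
                                                  stacked_alive_outputs))
    # softmax Jacobian: dL/dlogit_j = p_j * (dL/dp_j - sum_i p_i * dL/dp_i)
    grad_alive_logits = alive_expert_probas * (
        grad_probs - (alive_expert_probas * grad_probs).sum())
    # scatter back into the full logit vector; experts that failed or never
    # responded receive zero gradient
    grad_logits = torch.zeros_like(expert_logits)
    grad_logits[alive_ix] = grad_alive_logits
    return grad_logits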