Prechádzať zdrojové kódy

reweigh grads correctly

justheuristic pred 5 rokmi
rodič
commit
662357fcb3
1 zmenený súbor: 2 pridania a 1 odobranie
  1. tesseract/client/moe.py (+2 −1)

+ 2 - 1
tesseract/client/moe.py

@@ -225,7 +225,8 @@ class _RemoteMoECall(torch.autograd.Function):
         survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
 
         flat_grad_inputs = tuple(map(
-            lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),
+            lambda *tensors: sum(x * (weight / old_weight) for x, weight, old_weight
+                                 in zip(tensors, survived_expert_probas, alive_expert_probas[survived_backward])),
             *survived_grad_inputs))
 
         grad_logits = None  # TODO