justheuristic 5 anni fa
parent
commit
4723ce77b1
1 ha cambiato i file con 0 aggiunte e 4 eliminazioni
  1. 0 4
      tesseract/client/moe.py

+ 0 - 4
tesseract/client/moe.py

@@ -201,10 +201,6 @@ class _RemoteMoECall(torch.autograd.Function):
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
-        flat_average_outputs_ = tuple(map(
-            lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
-        assert torch.allclose(flat_average_outputs_[0], flat_average_outputs[0])
-        assert False
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
         ctx._alive_contexts = alive_contexts