justheuristic 5 tahun lalu
induk
melakukan
3d4132635a
1 mengubah file dengan 5 tambahan dan 4 penghapusan
  1. 5 4
      tesseract/client/moe.py

+ 5 - 4
tesseract/client/moe.py

@@ -199,11 +199,12 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-        # flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
-        #                              for stacked_out in stacked_alive_outputs)
-        flat_average_outputs = tuple(map(
+        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
+                                     for stacked_out in stacked_alive_outputs)
+        flat_average_outputs_ = tuple(map(
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
-
+        assert torch.allcloce(flat_average_outputs_[0], flat_average_outputs[0])
+        assert False
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
         ctx._alive_contexts = alive_contexts