소스 검색

wip: implement grad wrt logits

justheuristic 5 년 전
부모
커밋
b22dcf509c
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      tesseract/client/moe.py

+ 1 - 1
tesseract/client/moe.py

@@ -198,7 +198,7 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
-        stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
+        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
         print(flat_average_outputs)