
restore py37

justheuristic 5 years ago
parent
current commit
8b162e79f7
1 changed file with 1 addition and 1 deletion

hivemind/client/moe.py (+1 -1)

@@ -246,7 +246,7 @@ class _RemoteMoECall(torch.autograd.Function):
         softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
         grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
 
-        return grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs
+        return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)
 
     @staticmethod
     def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
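
The added parentheses are what restore Python 3.7 compatibility: before Python 3.8 (bpo-32117), iterable unpacking in a bare return statement is a SyntaxError, so return x, *rest fails on 3.7 while return (x, *rest) is accepted on both. A minimal sketch of the distinction, using placeholder values rather than the module's real gradient tensors:

    # SyntaxError on Python 3.7, valid only on 3.8+:
    #     return grad_wrt_logits, None, *flat_grad_inputs
    # Parenthesized form, valid on 3.7 and later:
    def backward_sketch():
        grad_wrt_logits = 1.0          # placeholder for the gradient w.r.t. logits
        flat_grad_inputs = [2.0, 3.0]  # placeholder for the flattened input gradients
        return (grad_wrt_logits, None, None, *flat_grad_inputs)

    print(backward_sketch())  # (1.0, None, None, 2.0, 3.0)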