
wip: parallel fault-tolerant moe backward pass

justheuristic 5 years ago
parent
commit
2b2ddf8280
1 changed file with 1 addition and 3 deletions
  1. tesseract/client/moe.py (+1, -3)

tesseract/client/moe.py (+1, -3)

@@ -177,8 +177,6 @@ class _RemoteMoECall(torch.autograd.Function):
     This function that can recover from individual failures during forward and/or backward passes.
     For user-friendly version of this function, use RemoteMixtureOfExperts module.
     """
-    MIN_TOTAL_WEIGHT = 1e-3
-
     @classmethod
     def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
                 *flat_inputs: torch.Tensor, input_schema, k_min: int, timeout_after_k_min: float, backward_k_min: int,
@@ -229,7 +227,7 @@ class _RemoteMoECall(torch.autograd.Function):
             *survived_grad_inputs))

         grad_logits = None  # TODO
-        return (grad_logits, None, *flat_grad_inputs, None, None, None, None, None, None)
+        return grad_logits, None, *flat_grad_inputs, None, None, None, None, None, None

     @staticmethod
     def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
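For context on the changed return statement: torch.autograd.Function.backward must return exactly one gradient per argument that forward received, in the same order, with None for non-differentiable arguments. That is why the tuple pads with None for experts, input_schema, and the scheduling hyperparameters (k_min, timeout_after_k_min, backward_k_min, ...); dropping the parentheses does not change the returned tuple. Below is a minimal sketch of that contract using a hypothetical toy function (_ScaleEach, not part of tesseract):

    import torch

    class _ScaleEach(torch.autograd.Function):
        """Toy function: output[i] = weights[i] * inputs[i]. Not tesseract code."""

        @staticmethod
        def forward(ctx, weights, k_min, *inputs):
            # k_min stands in for a non-tensor hyperparameter, like the ones
            # passed to _RemoteMoECall.forward; it receives no gradient.
            ctx.save_for_backward(weights, *inputs)
            return tuple(w * x for w, x in zip(weights, inputs))

        @staticmethod
        def backward(ctx, *grad_outputs):
            weights, *inputs = ctx.saved_tensors
            grad_weights = torch.stack([(g * x).sum() for g, x in zip(grad_outputs, inputs)])
            grad_inputs = [w * g for w, g in zip(weights, grad_outputs)]
            # exactly one entry per forward argument: weights, k_min (None), *inputs
            return grad_weights, None, *grad_inputs

    # usage: gradients flow to weights and each input; the None for k_min is discarded
    weights = torch.tensor([0.5, 2.0], requires_grad=True)
    xs = [torch.randn(3, requires_grad=True) for _ in range(2)]
    out = _ScaleEach.apply(weights, 1, *xs)
    torch.stack(out).sum().backward()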