justheuristic 4 年之前
父节点
当前提交
9172a9cd16
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      hivemind/moe/client/balanced_expert.py

+ 1 - 1
hivemind/moe/client/balanced_expert.py

@@ -62,7 +62,7 @@ class BalancedRemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
 
-        flat_inputs = nested_flatten(forward_inputs)
+        flat_inputs = list(nested_flatten(forward_inputs))
         forward_task_size = flat_inputs[0].shape[0]
 
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad