justheuristic, 4 years ago
parent commit 9172a9cd16
1 changed file with 1 addition and 1 deletion

hivemind/moe/client/balanced_expert.py  +1 -1

@@ -62,7 +62,7 @@ class BalancedRemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_inputs = nested_flatten(forward_inputs)
+        flat_inputs = list(nested_flatten(forward_inputs))
         forward_task_size = flat_inputs[0].shape[0]
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
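Context for the change: hivemind's `nested_flatten` is a generator, so the old code broke as soon as the result was indexed with `flat_inputs[0]`. The sketch below uses a hypothetical stand-in for `nested_flatten` (not hivemind's actual implementation) to reproduce the error and show why wrapping the call in `list(...)` fixes it.

```python
# Toy stand-in for hivemind's nested_flatten: yields the leaves of a nested structure.
def nested_flatten(t):
    if isinstance(t, (list, tuple)):
        for item in t:
            yield from nested_flatten(item)
    else:
        yield t


forward_inputs = ([1, 2], 3)

# Before the fix: a generator is not subscriptable.
flat_inputs = nested_flatten(forward_inputs)
try:
    flat_inputs[0]
except TypeError as e:
    print(e)  # 'generator' object is not subscriptable

# After the fix: materialize the generator, then indexing works.
flat_inputs = list(nested_flatten(forward_inputs))
print(flat_inputs[0])  # 1
```

Materializing the flattened inputs up front also allows them to be indexed or iterated more than once without exhausting the underlying iterator.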