justheuristic 4 years ago
parent
commit
9172a9cd16
1 changed files with 1 additions and 1 deletions
  1. 1 1
      hivemind/moe/client/balanced_expert.py

+ 1 - 1
hivemind/moe/client/balanced_expert.py

@@ -62,7 +62,7 @@ class BalancedRemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
 
-        flat_inputs = nested_flatten(forward_inputs)
+        flat_inputs = list(nested_flatten(forward_inputs))
         forward_task_size = flat_inputs[0].shape[0]
         forward_task_size = flat_inputs[0].shape[0]
 
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad