@@ -62,7 +62,7 @@ class BalancedRemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_inputs = nested_flatten(forward_inputs)
+        flat_inputs = list(nested_flatten(forward_inputs))
         forward_task_size = flat_inputs[0].shape[0]
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
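
Context for the `list(...)` wrapper: `nested_flatten` is, as far as I understand, a generator, so the old line produced a generator object that cannot be indexed, and `flat_inputs[0].shape[0]` on the following line would raise `TypeError: 'generator' object is not subscriptable`. Below is a minimal sketch of the failure and the fix; it uses a simplified stand-in for `nested_flatten` rather than hivemind's actual helper.

```python
# Simplified stand-in for nested_flatten (assumption: the real helper also
# yields leaves lazily, i.e. it returns a generator).
def nested_flatten(t):
    if isinstance(t, (list, tuple)):
        for item in t:
            yield from nested_flatten(item)
    elif isinstance(t, dict):
        for _, item in sorted(t.items()):
            yield from nested_flatten(item)
    else:
        yield t

forward_inputs = ([1, 2], {"x": 3})

# Old code: the generator itself is not subscriptable.
flat_inputs = nested_flatten(forward_inputs)
# flat_inputs[0]  # TypeError: 'generator' object is not subscriptable

# New code: materialize the generator first, then indexing works.
flat_inputs = list(nested_flatten(forward_inputs))
assert flat_inputs[0] == 1
```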