|
@@ -36,7 +36,7 @@ class BalancedRemoteExpert(nn.Module):
|
|
|
super().__init__()
|
|
|
if uid_prefix.endswith(".0."):
|
|
|
logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
|
|
|
- assert len(grid_size) == 2 and grid_size[0] == 0, "only 1xN grids are supported"
|
|
|
+ assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
|
|
|
self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
|
|
|
self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
|
|
|
self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
|