|
@@ -28,7 +28,6 @@ class BalancedRemoteExpert(nn.Module):
|
|
|
grid_size: Tuple[int, ...],
|
|
|
forward_timeout: Optional[float] = None,
|
|
|
backward_timeout: Optional[float] = None,
|
|
|
- detect_anomalies: bool = False,
|
|
|
update_period: float = 30.0,
|
|
|
backward_task_size_multiplier: float = 2.5,
|
|
|
**kwargs,
|
|
@@ -39,7 +38,7 @@ class BalancedRemoteExpert(nn.Module):
|
|
|
assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
|
|
|
self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
|
|
|
self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
|
|
|
- self.backward_task_size_multiplier, self.detect_anomalies = backward_task_size_multiplier, detect_anomalies
|
|
|
+ self.backward_task_size_multiplier = backward_task_size_multiplier
|
|
|
self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
|
|
|
self._expert_info = None # expert['info'] from one of experts in the grid
|
|
|
|