@@ -34,12 +34,12 @@ class BalancedRemoteExpert(nn.Module):
     ):
         super().__init__()
         if uid_prefix.endswith(".0."):
-            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}.0.")
+            logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}0.")
         assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
         self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.backward_task_size_multiplier = backward_task_size_multiplier
-        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}.0.", update_period=update_period, **kwargs)
+        self.expert_balancer = ExpertBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
         self._expert_info = None  # expert['info'] from one of experts in the grid

     def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
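For context, this change appears to track the uid convention in which `uid_prefix` already ends with a trailing dot (e.g. `"expert."`), so the old f-string inserted a second dot and produced a DHT lookup key that no expert uid would match. A minimal sketch of the key construction before and after, using a hypothetical prefix value:

```python
uid_prefix = "expert."  # hypothetical prefix; assumed convention is a trailing "."

# Old construction doubled the separator, so the lookup key matched no experts:
old_key = f"{uid_prefix}.0."  # -> "expert..0."
# New construction yields the intended key for experts named "expert.0.<index>":
new_key = f"{uid_prefix}0."   # -> "expert.0."

assert old_key == "expert..0." and new_key == "expert.0."
```

Under the same assumption, the updated `logger.warning` message shows the effective lookup key for callers who already baked a `.0.` suffix into their prefix, matching the key that `ExpertBalancer` will actually query.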