justheuristic 3 years ago
parent
commit
472dfb3095
1 file changed with 10 additions and 12 deletions

+10 -12
hivemind/optim/experimental/optimizer.py

@@ -273,6 +273,7 @@ class Optimizer(torch.optim.Optimizer):
         return TrainingStateAverager(
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
@@ -290,6 +291,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
@@ -463,11 +465,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_round=should_average_state,
                 delay_averaging=self.delay_state_averaging and not self.auxiliary,
                 averaging_control=self.scheduled_state if should_average_state else None,
-                averaging_opts=dict(
-                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
-                )
-                if should_average_state
-                else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
             if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
@@ -541,11 +539,9 @@ class Optimizer(torch.optim.Optimizer):
                     self.state_averager.step(wait_for_delayed_updates=True)
 
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
-                eta_seconds = max(eta_seconds, self.matchmaking_time)
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
-                self.scheduled_grads = self.grad_averager.schedule_step(
-                    scheduled_time=get_dht_time() + eta_seconds, timeout=self.averaging_timeout
-                )
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
@@ -562,10 +558,12 @@ class Optimizer(torch.optim.Optimizer):
 
         if eta_seconds_to_averaging <= self.matchmaking_time:
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
-                eta_seconds = max(eta_seconds_to_averaging, self.matchmaking_time)
-                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {eta_seconds:.2f}s.")
+
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
                 self.scheduled_state = self.state_averager.schedule_step(
-                    scheduled_time=get_dht_time() + eta_seconds, gather=next_epoch, timeout=self.averaging_timeout
+                    gather=next_epoch, timeout=self.averaging_timeout
                 )
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):