
min_matchmaking_time

justheuristic 3 years ago
parent commit 413fde64c3
1 changed file with 12 additions and 10 deletions

+ 12 - 10
hivemind/optim/experimental/optimizer.py

@@ -273,7 +273,6 @@ class Optimizer(torch.optim.Optimizer):
         return TrainingStateAverager(
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
-            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
@@ -291,7 +290,6 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
-            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
@@ -465,7 +463,11 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_round=should_average_state,
                 delay_averaging=self.delay_state_averaging and not self.auxiliary,
                 averaging_control=self.scheduled_state if should_average_state else None,
-                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+                averaging_opts=dict(
+                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
+                )
+                if should_average_state
+                else None,
             )

             if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
@@ -539,9 +541,11 @@ class Optimizer(torch.optim.Optimizer):
                     self.state_averager.step(wait_for_delayed_updates=True)

                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
-                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                eta_seconds = max(eta_seconds, self.matchmaking_time)
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
-                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+                self.scheduled_grads = self.grad_averager.schedule_step(
+                    scheduled_time=get_dht_time() + eta_seconds, timeout=self.averaging_timeout
+                )

     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
@@ -558,12 +562,10 @@ class Optimizer(torch.optim.Optimizer):

         if eta_seconds_to_averaging <= self.matchmaking_time:
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
-
-                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
-                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
-                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
+                eta_seconds = max(eta_seconds_to_averaging, self.matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {eta_seconds:.2f}s.")
                 self.scheduled_state = self.state_averager.schedule_step(
-                    gather=next_epoch, timeout=self.averaging_timeout
+                    scheduled_time=get_dht_time() + eta_seconds, gather=next_epoch, timeout=self.averaging_timeout
                 )

     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
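
For reference, the call pattern this commit moves to can be sketched as a standalone helper. This is a minimal illustration, not code from the repository: it assumes hivemind's get_dht_time and the schedule_step(scheduled_time=..., timeout=...) signature shown in the diff, and the optimizer argument stands in for the Optimizer instance itself.

# Illustrative sketch (not part of the commit): pre-schedule a gradient
# averaging round at an explicit DHT timestamp instead of relying on each
# averager's min_matchmaking_time constructor argument.
from hivemind.utils import get_dht_time

def preschedule_gradient_averaging(optimizer):
    # Seconds until the collaboration is expected to reach the next epoch
    eta_seconds = optimizer.tracker.estimated_next_update_time - get_dht_time()
    # Never keep matchmaking open for less than the configured window
    eta_seconds = max(eta_seconds, optimizer.matchmaking_time)
    # An absolute scheduled_time lets all peers target the same round start
    return optimizer.grad_averager.schedule_step(
        scheduled_time=get_dht_time() + eta_seconds,
        timeout=optimizer.averaging_timeout,
    )

Passing scheduled_time at each call site, rather than fixing min_matchmaking_time when the averagers are constructed, lets the caller fold the live ETA of the next epoch into the scheduling decision.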