
min_matchmaking_time

justheuristic, 3 years ago
commit 413fde64c3
1 file changed, 12 insertions(+), 10 deletions(-)

hivemind/optim/experimental/optimizer.py (+12 -10)

@@ -273,7 +273,6 @@ class Optimizer(torch.optim.Optimizer):
         return TrainingStateAverager(
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
-            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
@@ -291,7 +290,6 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
-            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
@@ -465,7 +463,11 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_round=should_average_state,
                 delay_averaging=self.delay_state_averaging and not self.auxiliary,
                 averaging_control=self.scheduled_state if should_average_state else None,
-                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+                averaging_opts=dict(
+                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
+                )
+                if should_average_state
+                else None,
             )
 
             if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
@@ -539,9 +541,11 @@ class Optimizer(torch.optim.Optimizer):
                     self.state_averager.step(wait_for_delayed_updates=True)
 
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
-                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                eta_seconds = max(eta_seconds, self.matchmaking_time)
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
-                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+                self.scheduled_grads = self.grad_averager.schedule_step(
+                    scheduled_time=get_dht_time() + eta_seconds, timeout=self.averaging_timeout
+                )
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
@@ -558,12 +562,10 @@ class Optimizer(torch.optim.Optimizer):
 
         if eta_seconds_to_averaging <= self.matchmaking_time:
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
-
-                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
-                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
-                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
+                eta_seconds = max(eta_seconds_to_averaging, self.matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {eta_seconds:.2f}s.")
                 self.scheduled_state = self.state_averager.schedule_step(
-                    gather=next_epoch, timeout=self.averaging_timeout
+                    scheduled_time=get_dht_time() + eta_seconds, gather=next_epoch, timeout=self.averaging_timeout
                 )
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
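
The net effect of this commit is that Optimizer stops passing min_matchmaking_time to its TrainingStateAverager and gradient averager at construction time; instead, it computes the desired start time itself (current DHT time plus matchmaking_time, or the estimated time until the next update, whichever is larger) and passes it explicitly as scheduled_time when pre-scheduling an averaging round. Below is a minimal sketch of that calling pattern; the FakeAverager class and dht_time helper are illustrative stand-ins for the hivemind averager API and get_dht_time, not actual library code.

import time
from typing import Optional


def dht_time() -> float:
    """Stand-in for get_dht_time(); here it is just wall-clock time."""
    return time.time()


class FakeAverager:
    """Illustrative stand-in for an averager exposing a schedule_step-like method."""

    def schedule_step(self, scheduled_time: Optional[float] = None, timeout: Optional[float] = None):
        # In hivemind, this would return a StepControl handle for the scheduled round.
        print(f"averaging round scheduled to start at t={scheduled_time:.2f} (timeout={timeout})")
        return object()


# Before this commit: the optimizer forwarded min_matchmaking_time to the averager,
# which derived the start time internally when schedule_step() was called.
# After this commit: the caller computes the start time and passes it explicitly,
# mirroring _maybe_schedule_state_averaging in the diff above.
matchmaking_time = 15.0
averaging_timeout = 60.0
estimated_next_update_time = dht_time() + 5.0  # e.g. tracker.estimated_next_update_time

eta_seconds = max(estimated_next_update_time - dht_time(), matchmaking_time)
averager = FakeAverager()
scheduled = averager.schedule_step(scheduled_time=dht_time() + eta_seconds, timeout=averaging_timeout)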