@@ -273,6 +273,7 @@ class Optimizer(torch.optim.Optimizer):
         return TrainingStateAverager(
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
@@ -290,6 +291,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
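Note on the two hunks above: both averagers now receive the matchmaking window once, at construction, instead of every call site recomputing a scheduled_time from it. Below is a minimal sketch of the pattern the later hunks rely on, assuming the averager stashes its matchmaking options in a matchmaking_kwargs dict; the class is illustrative, not hivemind's actual implementation, and time.time() stands in for get_dht_time().

    # Illustrative sketch only, not hivemind's real averager.
    import time
    from typing import Optional

    class AveragerSketch:
        def __init__(self, *, min_matchmaking_time: float):
            # Stored once at construction; later hunks read it back via
            # averager.matchmaking_kwargs["min_matchmaking_time"].
            self.matchmaking_kwargs = dict(min_matchmaking_time=min_matchmaking_time)

        def schedule_step(self, timeout: Optional[float] = None) -> float:
            # scheduled_time is derived internally from the stored window,
            # so callers no longer pass it explicitly.
            return time.time() + self.matchmaking_kwargs["min_matchmaking_time"]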
@@ -463,11 +465,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_round=should_average_state,
                 delay_averaging=self.delay_state_averaging and not self.auxiliary,
                 averaging_control=self.scheduled_state if should_average_state else None,
-                averaging_opts=dict(
-                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
-                )
-                if should_average_state
-                else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
             if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
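The hunk above removes the last per-step use of scheduled_time at this call site. An illustrative before/after of the options dict; the variable values are stand-ins chosen only to keep the example self-contained:

    import time

    matchmaking_time = 15.0    # stand-in for self.matchmaking_time
    averaging_timeout = 60.0   # stand-in for self.averaging_timeout

    # Before: the caller picked the matchmaking deadline on every step.
    opts_before = dict(scheduled_time=time.time() + matchmaking_time, timeout=averaging_timeout)

    # After: only the timeout crosses the boundary; the averager applies the
    # min_matchmaking_time it was given at construction (see the hunks above).
    opts_after = dict(timeout=averaging_timeout)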
@@ -541,11 +539,9 @@ class Optimizer(torch.optim.Optimizer):
             self.state_averager.step(wait_for_delayed_updates=True)
 
         eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
-        eta_seconds = max(eta_seconds, self.matchmaking_time)
+        eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
         logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
-        self.scheduled_grads = self.grad_averager.schedule_step(
-            scheduled_time=get_dht_time() + eta_seconds, timeout=self.averaging_timeout
-        )
+        self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
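In the hunk above, the ETA used for the log message is now read back from the gradient averager's own matchmaking_kwargs rather than from the Optimizer's matchmaking_time attribute, so the logged window and the one the averager actually enforces cannot drift apart. A sketch of that computation in isolation, with time.time() again standing in for get_dht_time():

    import time

    def eta_for_log(estimated_next_update_time: float, matchmaking_kwargs: dict) -> float:
        # Never report an ETA shorter than the averager's own matchmaking window.
        eta_seconds = estimated_next_update_time - time.time()
        return max(eta_seconds, matchmaking_kwargs["min_matchmaking_time"])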
@@ -562,10 +558,12 @@ class Optimizer(torch.optim.Optimizer):
 
         if eta_seconds_to_averaging <= self.matchmaking_time:
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
-                eta_seconds = max(eta_seconds_to_averaging, self.matchmaking_time)
-                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {eta_seconds:.2f}s.")
+
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
                 self.scheduled_state = self.state_averager.schedule_step(
-                    scheduled_time=get_dht_time() + eta_seconds, gather=next_epoch, timeout=self.averaging_timeout
+                    gather=next_epoch, timeout=self.averaging_timeout
                 )
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
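End to end, a user still configures a single matchmaking_time on the Optimizer; after this change it propagates to both internal averagers as min_matchmaking_time instead of being re-derived at each scheduling call. A usage sketch, assuming hivemind's public Optimizer API, with illustrative values:

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 1)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
        target_batch_size=256,
        batch_size_per_step=32,
        matchmaking_time=15.0,    # forwarded to both averagers as min_matchmaking_time
        averaging_timeout=60.0,
    )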