
[final] adjust state scheduling

justheuristic 3 years ago
parent
commit 4b915eea11
1 changed file with 17 additions and 2 deletions

hivemind/optim/experimental/optimizer.py (+17, -2)

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
+import time
 from functools import partial
 from typing import Callable, Optional, Sequence, Union
 
@@ -22,7 +23,7 @@ from hivemind.optim.experimental.state_averager import (
     TrainingStateAverager,
 )
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger, PerformanceEMA
 
 logger = get_logger(__name__)
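
Note on the new import: PerformanceEMA maintains an exponentially weighted moving average of how long each unit of work takes, exposed as ema_seconds_per_sample in the hunks below. A minimal sketch of the underlying idea, not hivemind's actual implementation (the SimpleEMA name and internals are illustrative only):

    class SimpleEMA:
        """Illustrative exponential moving average of seconds per unit of work."""

        def __init__(self, alpha: float = 0.1):
            self.alpha = alpha  # higher alpha reacts faster to new measurements
            self.ema_seconds_per_sample = 0.0
            self._initialized = False

        def update(self, task_size: float, interval: float) -> None:
            sample = interval / max(task_size, 1e-9)  # seconds per unit of work
            if not self._initialized:
                self.ema_seconds_per_sample, self._initialized = sample, True
            else:
                self.ema_seconds_per_sample = (
                    self.alpha * sample + (1.0 - self.alpha) * self.ema_seconds_per_sample
                )

With the default alpha=0.1 introduced below, each new measurement contributes only a tenth of its weight to the estimate, so a one-off slow epoch barely moves the prediction.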
 
@@ -113,6 +114,7 @@ class Optimizer(torch.optim.Optimizer):
 
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha used in ProgressTracker, TrainingStateAverager, and Optimizer
     :param verbose: if True, report internal events such as accumulating gradients and running background tasks
 
     Internally, hivemind.Optimizer consists of 4 components:
@@ -153,6 +155,7 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
     ):
@@ -193,7 +196,9 @@ class Optimizer(torch.optim.Optimizer):
         self.scheduled_grads: Optional[StepControl] = None
         self.scheduled_state: Optional[StepControl] = None
 
-        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
         self.state_averager = self._make_state_averager(
             optimizer=optimizer,
             params=params,
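
Reviewer note: performance_ema_alpha is now passed explicitly while tracker_opts is still unpacked into the same call, so a tracker_opts dict that also carries performance_ema_alpha would raise a TypeError at call time. A self-contained sketch of that behavior (make_tracker is a hypothetical stand-in, not the real helper):

    def make_tracker(target_batch_size, performance_ema_alpha=0.1, **kwargs):
        return dict(target_batch_size=target_batch_size,
                    performance_ema_alpha=performance_ema_alpha, **kwargs)

    print(make_tracker(256, performance_ema_alpha=0.05))  # explicit value is used
    try:
        make_tracker(256, performance_ema_alpha=0.05, **{"performance_ema_alpha": 0.2})
    except TypeError as err:
        print(err)  # got multiple values for keyword argument 'performance_ema_alpha'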
@@ -202,6 +207,7 @@ class Optimizer(torch.optim.Optimizer):
             compression=state_averaging_compression,
             state_compression=load_state_compression,
             average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
@@ -216,6 +222,10 @@ class Optimizer(torch.optim.Optimizer):
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
 
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
         self._step_supports_amp_scaling = reuse_grad_buffers
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
         # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
@@ -363,6 +373,7 @@ class Optimizer(torch.optim.Optimizer):
     def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
         """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
 
         with self.tracker.pause_updates():
             wait_for_trigger = None
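
time.perf_counter() is the right clock for this measurement: it is monotonic and high-resolution, so the interval is immune to system clock adjustments, whereas get_dht_time() follows the swarm-synchronized wall clock. A quick sketch of the interval-measurement pattern:

    import time

    start = time.perf_counter()
    time.sleep(0.05)  # stand-in for gradient aggregation and bookkeeping
    elapsed = time.perf_counter() - start
    print(f"waited {elapsed:.3f}s")  # about 0.050, regardless of wall-clock changes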
@@ -384,6 +395,9 @@ class Optimizer(torch.optim.Optimizer):
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
             should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
 
+            if should_average_state:
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
+
             self.state_averager.step(
                 increment_epoch=True,
                 wait_for_trigger=wait_for_trigger,
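
Each epoch that triggers state averaging folds the time elapsed since _epoch_start_time into the EMA with task_size=1, i.e. one measurement per averaged epoch. A self-contained sketch of the same update rule (the sleep stands in for real work):

    import time

    alpha, ema = 0.1, None  # mirrors PerformanceEMA(alpha=performance_ema_alpha)
    for _ in range(3):
        epoch_start = time.perf_counter()
        time.sleep(0.02)  # work done before the state_averager.step call
        interval = time.perf_counter() - epoch_start  # task_size=1: seconds per sample
        ema = interval if ema is None else alpha * interval + (1 - alpha) * ema
    print(f"expected pre-averaging delay: {ema:.3f}s")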
@@ -477,6 +491,7 @@ class Optimizer(torch.optim.Optimizer):
             return  # averaging is not performed at this epoch
 
         estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
         estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
         eta_seconds_to_averaging = estimated_time - get_dht_time()
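
Putting the scheduling arithmetic together: the predicted start of the next averaging round is the tracker's estimate for the next epoch plus both measured lead times, and the ETA is that timestamp minus the current DHT clock. A worked sketch with made-up numbers (all values hypothetical):

    # all values hypothetical, in seconds of DHT time
    estimated_next_update_time = 1000.0  # tracker's prediction of the next global epoch
    delay_before_state_averaging = 1.5   # the EMA added by this commit (Optimizer-side)
    delay_before_averaging = 0.8         # the EMA inside TrainingStateAverager
    dht_now = 995.0

    estimated_time = estimated_next_update_time + delay_before_state_averaging + delay_before_averaging
    eta_seconds_to_averaging = estimated_time - dht_now
    print(eta_seconds_to_averaging)  # about 7.3: schedule the round this far in advance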