@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
+import time
 from functools import partial
 from typing import Callable, Optional, Sequence, Union
 
@@ -22,7 +23,7 @@ from hivemind.optim.experimental.state_averager import (
     TrainingStateAverager,
 )
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger, PerformanceEMA
 
 logger = get_logger(__name__)
 
@@ -113,6 +114,7 @@ class Optimizer(torch.optim.Optimizer):
 
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
     :param verbose: if True, report internal events such as accumulating gradients and running background tasks
 
     Internally, hivemind.Optimizer consists of 4 components:
@@ -153,6 +155,7 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
     ):
@@ -193,7 +196,9 @@ class Optimizer(torch.optim.Optimizer):
         self.scheduled_grads: Optional[StepControl] = None
         self.scheduled_state: Optional[StepControl] = None
 
-        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
         self.state_averager = self._make_state_averager(
             optimizer=optimizer,
             params=params,
@@ -202,6 +207,7 @@ class Optimizer(torch.optim.Optimizer):
             compression=state_averaging_compression,
             state_compression=load_state_compression,
             average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
@@ -216,6 +222,10 @@ class Optimizer(torch.optim.Optimizer):
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
 
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
         self._step_supports_amp_scaling = reuse_grad_buffers
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
         # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
@@ -363,6 +373,7 @@ class Optimizer(torch.optim.Optimizer):
     def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
         """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
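+        # note: used later in this method to measure the delay before state averaging begins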
 
         with self.tracker.pause_updates():
             wait_for_trigger = None
@@ -384,6 +395,9 @@ class Optimizer(torch.optim.Optimizer):
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
             should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
 
+            if should_average_state:
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
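+                # note: the line above records how long this peer took to reach state averaging in this epoch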
+
             self.state_averager.step(
                 increment_epoch=True,
                 wait_for_trigger=wait_for_trigger,
@@ -477,6 +491,7 @@ class Optimizer(torch.optim.Optimizer):
             return  # averaging is not performed at this epoch
 
         estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
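+        # note: the line above shifts the ETA by the average delay between the epoch update and state averaging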
         estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
         eta_seconds_to_averaging = self.tracker.estimated_next_update_time - get_dht_time()
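
Note (illustration only, not part of the patch): the sketch below shows the idea behind the new delay tracking. SimpleEMA is a hypothetical, simplified stand-in for hivemind.utils.PerformanceEMA, and the 30-second estimate is a made-up placeholder for the progress tracker's prediction; the real classes and values differ.

    import time


    class SimpleEMA:
        """Exponential moving average of seconds per task (simplified stand-in for PerformanceEMA)."""

        def __init__(self, alpha: float = 0.1):
            self.alpha = alpha
            self.ema_seconds_per_sample = 0.0

        def update(self, task_size: float, interval: float) -> None:
            seconds_per_sample = interval / task_size
            self.ema_seconds_per_sample += self.alpha * (seconds_per_sample - self.ema_seconds_per_sample)


    delay_before_state_averaging = SimpleEMA(alpha=0.1)

    # measure how long one epoch update takes before it reaches state averaging
    epoch_start = time.perf_counter()
    time.sleep(0.05)  # stand-in for gradient averaging and local bookkeeping
    delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - epoch_start)

    # when pre-scheduling the next averaging round, shift the estimated start time by the measured average delay
    estimated_next_update_time = time.time() + 30.0  # stand-in for tracker.estimated_next_update_time
    estimated_time = estimated_next_update_time + delay_before_state_averaging.ema_seconds_per_sample
    print(f"pre-schedule state averaging at t={estimated_time:.1f}")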