
Add ProgressTracker (#408)

Auxiliary class that keeps track of local & global training progress, measured in epochs.
An epoch can be incremented after the collaboration accumulates a target number of gradients (target_batch_size).
Similarly to the PyTorch LR scheduler, the epoch can be incremented after a single optimizer update or after many local updates.
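
A minimal, hedged sketch (not part of this commit) of a monitoring-only peer that merely watches swarm-wide progress; it assumes `import time` and a `hivemind.DHT` instance `dht` connected to the same swarm, and the prefix / target_batch_size values are placeholders that must match the training peers:

>>> tracker = ProgressTracker(dht, prefix="my_experiment", target_batch_size=256, start=True)
>>> while tracker.global_progress.epoch < 10:
>>>     time.sleep(1.0)
>>>     progress = tracker.global_progress
>>>     print(f"epoch {progress.epoch}: {progress.samples_accumulated} samples from {progress.num_peers} peers")
>>> tracker.shutdown()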

Co-authored-by: Anton Sinitsin <ant.sinitsin@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic, 3 years ago
parent
commit d883387e37
2 changed files with 430 additions and 0 deletions
  1. hivemind/optim/experimental/progress_tracker.py  (+321, -0)
  2. tests/test_optimizer.py  (+109, -0)

hivemind/optim/experimental/progress_tracker.py  (+321, -0)

@@ -0,0 +1,321 @@
+import asyncio
+import contextlib
+import logging
+import threading
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import numpy as np
+from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
+
+from hivemind.dht import DHT
+from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
+from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.performance_ema import PerformanceEMA
+
+logger = get_logger(__name__)
+
+
+@dataclass(frozen=False)
+class GlobalTrainingProgress:
+    epoch: int
+    samples_accumulated: int
+    target_batch_size: int
+    num_peers: int
+    num_clients: int
+    eta_next_epoch: float
+    next_fetch_time: float
+
+
+class LocalTrainingProgress(BaseModel):
+    peer_id: bytes
+    epoch: conint(ge=0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    time: StrictFloat
+    client_mode: StrictBool
+
+
+class TrainingProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]
+
+
+class ProgressTracker(threading.Thread):
+    """
+    Auxiliary class that keeps track of local & global training progress, measured in epochs.
+    An epoch can be incremented after the collaboration accumulates a target number of gradients (target_batch_size).
+    Similarly to the PyTorch LR scheduler, the epoch can be incremented after a single optimizer update or after many local updates.
+
+    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
+    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
+    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
+    :param expected_drift_peers: assume that this many new peers can join between epochs
+    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between epochs
+    :note: The expected collaboration drift parameters are used to adjust the frequency with which this tracker will
+      refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
+    :param performance_ema_alpha: smoothing factor used to estimate this peer's performance (samples per second)
+    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored in the DHT for this many seconds
+
+    Example:
+
+    >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100, start=True)
+    >>> local_epoch, local_samples = 0, 0
+    >>> while True:
+    >>>     accumulate_gradients(batch_size=32)
+    >>>     local_samples += 32
+    >>>     tracker.report_local_progress(local_epoch, local_samples)
+    >>>     if local_epoch < tracker.global_progress.epoch:
+    >>>         download_state_from_peers()  # if peer is out of sync, synchronize it with the swarm
+    >>>     if tracker.ready_to_update_epoch:
+    >>>         with tracker.pause_updates():
+    >>>             aggregate_gradients_with_peers()
+    >>>             update_model_parameters()
+    >>>             local_epoch = tracker.update_epoch(local_epoch + 1)
+    >>>             local_samples = 0
+    """
+
+    def __init__(
+        self,
+        dht: DHT,
+        prefix: str,
+        target_batch_size: int,
+        *,
+        client_mode: Optional[bool] = None,
+        min_refresh_period: float = 0.5,
+        max_refresh_period: float = 30,
+        default_refresh_period: float = 3,
+        expected_drift_peers: float = 3,
+        expected_drift_rate: float = 0.2,
+        performance_ema_alpha: float = 0.1,
+        metadata_expiration: float = 30.0,
+        status_loglevel: int = logging.DEBUG,
+        private_key: Optional[RSAPrivateKey] = None,
+        daemon: bool = True,
+        start: bool,
+    ):
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
+        self.training_progress_key = f"{self.prefix}_progress"
+        self.target_batch_size = target_batch_size
+        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
+        self.default_refresh_period = default_refresh_period
+        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
+        self.status_loglevel = status_loglevel
+        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
+        self.metadata_expiration = metadata_expiration
+
+        signature_validator = RSASignatureValidator(private_key)
+        self._local_public_key = signature_validator.local_public_key
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
+
+        # report the collaboration progress periodically or in background
+        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
+        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
+        self.global_progress = self._parse_swarm_progress_data(metadata)
+        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
+        self.should_report_progress = threading.Event()
+        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
+        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
+        if start:
+            self.start()
+
+    @property
+    def global_epoch(self) -> int:
+        return self.global_progress.epoch
+
+    @property
+    def ready_to_update_epoch(self) -> bool:
+        """Whether or not this peer can increment epoch right away."""
+        return (
+            self.global_epoch > self.local_progress.epoch
+            or self.global_progress.samples_accumulated >= self.target_batch_size
+            or get_dht_time() >= self.global_progress.eta_next_epoch
+        )
+
+    @property
+    def estimated_next_update_time(self) -> DHTExpiration:
+        """Estimate (absolute) time when this peer should increment epoch"""
+        if self.ready_to_update_epoch:
+            return get_dht_time()
+        return self.global_progress.eta_next_epoch
+
+    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
+        return LocalTrainingProgress(
+            peer_id=self.dht.peer_id.to_bytes(),
+            epoch=local_epoch,
+            samples_accumulated=samples_accumulated,
+            samples_per_second=self.performance_ema.samples_per_second,
+            time=get_dht_time(),
+            client_mode=self.client_mode,
+        )
+
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
+        """Update the number of locally accumulated samples and notify to other peers about this."""
+        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if extra_samples > 0:
+            self.performance_ema.update(task_size=extra_samples)
+            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
+        else:
+            logger.debug("Resetting performance timestamp to current time (progress was reset)")
+            self.performance_ema.reset_timer()
+        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
+        self.should_report_progress.set()
+
+    @contextlib.contextmanager
+    def pause_updates(self):
+        """Temporarily stop progress tracker from updating global training state"""
+        with self.lock_global_progress, self.performance_ema.pause():
+            yield
+
+    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
+        """Update the local epoch, reset the number of sample accumulated, reset local progress, return new epoch"""
+        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
+        if new_epoch is None:
+            new_epoch = self.local_progress.epoch + 1
+        if new_epoch > self.global_progress.epoch:
+            self.global_progress.epoch = new_epoch
+            self.global_progress.samples_accumulated = 0
+            self.global_progress.eta_next_epoch = float("inf")
+        self.report_local_progress(new_epoch, samples_accumulated=0)
+        return new_epoch
+
+    def run(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
+        self.shutdown_complete.set()
+
+    async def _progress_reporter(self):
+        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
+        last_report_time = -float("inf")
+        try:
+            while not self.shutdown_triggered.is_set():
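+                # wake up when the previously stored DHT entry expires, or earlier if report_local_progress fires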
+                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
+                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
+                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
+                if self.should_report_progress.is_set():
+                    logger.debug("Progress update triggered by report_local_progress.")
+                    self.should_report_progress.clear()
+                else:
+                    logger.debug("Progress update triggered by metadata_expiration.")
+
+                local_progress = self.local_progress
+                last_report_time = get_dht_time()
+
+                await self.dht.store(
+                    key=self.training_progress_key,
+                    subkey=self._local_public_key,
+                    value=local_progress.dict(),
+                    expiration_time=last_report_time + self.metadata_expiration,
+                    return_future=True,
+                )
+        finally:
+            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
+
+    async def _progress_fetcher(self):
+        """
+        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
+        """
+        loop = asyncio.get_event_loop()
+        try:
+            while not self.shutdown_triggered.is_set():
+                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
+                state_updated_externally = await loop.run_in_executor(
+                    None, self.global_state_updated.wait, time_to_next_update
+                )
+                if state_updated_externally:
+                    self.global_state_updated.clear()
+                    continue
+
+                async with enter_asynchronously(self.lock_global_progress):
+                    progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
+                    metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
+                    self.global_progress = self._parse_swarm_progress_data(metadata)
+        finally:
+            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
+
+    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
+        """Read performance statistics reported by peers, estimate progress towards next batch"""
+        current_time = get_dht_time()
+
+        if not isinstance(metadata, dict) or len(metadata) == 0:
+            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
+            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
+            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second
+
+            return GlobalTrainingProgress(
+                self.local_progress.epoch,
+                self.local_progress.samples_accumulated,
+                self.target_batch_size,
+                num_peers=0,
+                num_clients=0,
+                eta_next_epoch=current_time + local_eta_next_epoch,
+                next_fetch_time=current_time + self.default_refresh_period,
+            )
+
+        valid_peer_entries = [
+            LocalTrainingProgress.parse_obj(peer_state.value)
+            for peer_state in metadata.values()
+            if peer_state.value is not None
+        ]
+
+        num_peers = len(valid_peer_entries)
+        num_clients = sum(peer.client_mode for peer in valid_peer_entries)
+
+        global_epoch = self.local_progress.epoch
+        for peer in valid_peer_entries:
+            if not peer.client_mode:
+                global_epoch = max(global_epoch, peer.epoch)
+
+        total_samples_accumulated = estimated_current_samples = 0
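+        # start from performance_ema.eps (a small constant) so the ETA division below cannot divide by zero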
+        total_samples_per_second = self.performance_ema.eps
+
+        for peer in valid_peer_entries:
+            total_samples_per_second += peer.samples_per_second
+            if peer.epoch == global_epoch:
+                total_samples_accumulated += peer.samples_accumulated
+                estimated_current_samples += (
+                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
+                )
+            # note: we deliberately count samples_accumulated only from peers at the current global epoch,
+            # but use all peers for performance; the rationale is that outdated peers will synchronize
+            # and begin contributing shortly.
+
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
+        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second
+
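+        # poll the DHT more often as the next epoch approaches, assuming up to expected_drift_peers extra peers
+        # (or an expected_drift_rate fraction) may join in the meantime; clip to [min_refresh_period, max_refresh_period]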
+        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
+        time_to_next_fetch = float(
+            np.clip(
+                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
+                a_min=self.min_refresh_period,
+                a_max=self.max_refresh_period,
+            )
+        )
+        logger.log(
+            self.status_loglevel,
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
+            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
+        )
+        return GlobalTrainingProgress(
+            global_epoch,
+            total_samples_accumulated,
+            target_batch_size=self.target_batch_size,
+            num_peers=num_peers,
+            num_clients=num_clients,
+            eta_next_epoch=current_time + estimated_time_to_next_epoch,
+            next_fetch_time=current_time + time_to_next_fetch,
+        )
+
+    def shutdown(self):
+        """Permanently disable all tracking activity"""
+        self.shutdown_triggered.set()
+        self.should_report_progress.set()
+        self.global_state_updated.set()
+        self.shutdown_complete.wait()
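+        # store a None entry under this peer's subkey so that other peers can see it has left the collaboration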
+        self.dht.store(
+            self.training_progress_key,
+            subkey=self._local_public_key,
+            value=None,
+            expiration_time=get_dht_time() + self.metadata_expiration,
+        )

tests/test_optimizer.py  (+109, -0)

@@ -1,3 +1,5 @@
+import ctypes
+import multiprocessing as mp
 import time
 from functools import partial
 
@@ -10,7 +12,9 @@ import torch.nn.functional as F
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
+from hivemind.utils.crypto import RSAPrivateKey
 
 
 @pytest.mark.forked
@@ -170,3 +174,108 @@ def test_load_state_from_peers():
     assert avgr1.local_epoch == 1337
     assert torch.all(model1.weight == 42).item()
     assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
+
+
+@pytest.mark.forked
+def test_progress_tracker():
+    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
+    prefix = "my_exp"
+    target_batch_size = 256
+    dht_root = hivemind.DHT(start=True)
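+    # shared state: a barrier for the 4 workers plus the main process, events that delay worker 4 and stop the run,
+    # and an array that collects each worker's samples-per-second EMA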
+    barrier = mp.Barrier(parties=5)
+    delayed_start_evt = mp.Event()
+    finished_evt = mp.Event()
+    emas = mp.Array(ctypes.c_double, 5)
+
+    def run_worker(index: int, batch_size: int, period: float, **kwargs):
+        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
+        tracker = ProgressTracker(
+            dht,
+            prefix,
+            target_batch_size,
+            start=True,
+            min_refresh_period=0.1,
+            default_refresh_period=0.2,
+            max_refresh_period=0.5,
+            private_key=RSAPrivateKey(),
+            **kwargs,
+        )
+
+        barrier.wait()
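+        # worker 4 simulates a fast peer that joins late: it waits until the swarm reaches epoch 2, then reports from epoch 2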
+        if index == 4:
+            delayed_start_evt.wait()
+
+        local_epoch = 2 if index == 4 else 0
+        samples_accumulated = 0
+
+        while True:
+            time.sleep(period)
+            if finished_evt.is_set():
+                break
+
+            samples_accumulated += batch_size
+            tracker.report_local_progress(local_epoch, samples_accumulated)
+
+            if tracker.ready_to_update_epoch:
+                with tracker.pause_updates():
+                    local_epoch = tracker.update_epoch(local_epoch + 1)
+                    samples_accumulated = 0
+
+                if index == 4 and local_epoch >= 5:
+                    time.sleep(0.5)
+                    break
+
+        emas[index] = tracker.performance_ema.samples_per_second
+        tracker.shutdown()
+        dht.shutdown()
+
+    workers = [
+        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
+        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
+        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
+        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
+    ]
+    for worker in workers:
+        worker.start()
+
+    tracker = ProgressTracker(
+        dht_root,
+        prefix,
+        target_batch_size,
+        start=True,
+        min_refresh_period=0.1,
+        default_refresh_period=0.2,
+        max_refresh_period=0.5,
+    )
+    barrier.wait()
+
+    current_step = 0
+    last_timestamp = hivemind.get_dht_time()
+    step_time_deltas = []
+
+    while current_step < 6:
+        time.sleep(0.1)
+        if tracker.global_progress.epoch > current_step:
+            time_delta = hivemind.get_dht_time() - last_timestamp
+            current_step = tracker.global_progress.epoch
+            if current_step == 2:
+                delayed_start_evt.set()
+
+            last_timestamp = hivemind.get_dht_time()
+            step_time_deltas.append(time_delta)
+
+    finished_evt.set()
+    for worker in workers:
+        worker.join()
+
+    tracker.shutdown()
+    dht_root.shutdown()
+    assert not tracker.is_alive()
+
+    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
+    for i in (0, 1, 5):  # Without the 4th worker (the fastest one)
+        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
+    for i in (2, 3, 4):  # With the 4th worker
+        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
+    assert emas[1] < emas[2] < emas[3] < emas[4]
+    assert tracker.performance_ema.samples_per_second < 1e-9