import asyncio
import contextlib
import logging
import threading
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA

logger = get_logger(__name__)


@dataclass(frozen=False)
class GlobalTrainingProgress:
    epoch: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_epoch: float
    next_fetch_time: float


class LocalTrainingProgress(BaseModel):
    peer_id: bytes
    epoch: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]
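
# Each peer publishes its LocalTrainingProgress under f"{prefix}_progress" in the DHT, using its own
# public key as the subkey. SchemaValidator enforces the TrainingProgressSchema types above, and
# RSASignatureValidator ensures that a subkey can only be written by the peer holding the matching private key.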


class ProgressTracker(threading.Thread):
    """
    Auxiliary class that keeps track of local & global training progress, measured in epochs.
    An epoch is incremented after the collaboration accumulates a target number of gradients (target_batch_size).
    Similarly to a PyTorch LR scheduler, an epoch may correspond to a single optimizer update or several local updates.

    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
    :param expected_drift_peers: assume that this many new peers can join between epochs
    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between epochs
    :note: The expected collaboration drift parameters are used to adjust the frequency with which this tracker
      refreshes collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch).
    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds

    Example:

    >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
    >>> local_epoch, local_samples = 0, 0
    >>> while True:
    >>>     accumulate_gradients(batch_size=32)
    >>>     local_samples += 32
    >>>     tracker.report_local_progress(local_epoch, local_samples)
    >>>     if local_epoch < tracker.global_progress.epoch:
    >>>         download_state_from_peers()  # if peer is out of sync, synchronize it with the swarm
    >>>     if tracker.ready_to_update_epoch:
    >>>         with tracker.pause_updates():
    >>>             aggregate_gradients_with_peers()
    >>>             update_model_parameters()
    >>>             local_epoch = tracker.update_epoch(local_epoch + 1)
    >>>             local_samples = 0
    """

    def __init__(
        self,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        *,
        client_mode: Optional[bool] = None,
        min_refresh_period: float = 0.5,
        max_refresh_period: float = 30,
        default_refresh_period: float = 3,
        expected_drift_peers: float = 3,
        expected_drift_rate: float = 0.2,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 30.0,
        status_loglevel: int = logging.DEBUG,
        private_key: Optional[RSAPrivateKey] = None,
        daemon: bool = True,
        start: bool,
    ):
        client_mode = client_mode if client_mode is not None else dht.client_mode
        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
        self.training_progress_key = f"{self.prefix}_progress"
        self.target_batch_size = target_batch_size
        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
        self.default_refresh_period = default_refresh_period
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.status_loglevel = status_loglevel
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.metadata_expiration = metadata_expiration

        signature_validator = RSASignatureValidator(private_key)
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

        # fetch the initial local & global progress once, before the background thread takes over reporting
        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        self.global_progress = self._parse_swarm_progress_data(metadata)

        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
        self.should_report_progress = threading.Event()
        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
        if start:
            self.start()

    @property
    def global_epoch(self) -> int:
        return self.global_progress.epoch

    @property
    def ready_to_update_epoch(self) -> bool:
        """Whether or not this peer can increment epoch right away."""
        return (
            self.global_epoch > self.local_progress.epoch
            or self.global_progress.samples_accumulated >= self.target_batch_size
            or get_dht_time() >= self.global_progress.eta_next_epoch
        )
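
    # ready_to_update_epoch holds if (1) the swarm has already advanced past our local epoch, (2) the swarm
    # accumulated enough samples for the next epoch, or (3) the previously estimated epoch deadline has passed.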

    @property
    def estimated_next_update_time(self) -> DHTExpiration:
        """Estimate (absolute) time when this peer should increment epoch"""
        if self.ready_to_update_epoch:
            return get_dht_time()
        return self.global_progress.eta_next_epoch

    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
        return LocalTrainingProgress(
            peer_id=self.dht.peer_id.to_bytes(),
            epoch=local_epoch,
            samples_accumulated=samples_accumulated,
            samples_per_second=self.performance_ema.samples_per_second,
            time=get_dht_time(),
            client_mode=self.client_mode,
        )

    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
        """Update the number of locally accumulated samples and notify other peers about this."""
        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
        if extra_samples > 0:
            self.performance_ema.update(task_size=extra_samples)
            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
        else:
            logger.debug("Resetting performance timestamp to current time (progress was reset)")
            self.performance_ema.reset_timer()
        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
        self.should_report_progress.set()
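
    # note: setting should_report_progress wakes the background reporter immediately, so fresh progress reaches
    # the DHT right away instead of waiting for the next metadata_expiration-driven publish.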

    @contextlib.contextmanager
    def pause_updates(self):
        """Temporarily stop progress tracker from updating global training state"""
        with self.lock_global_progress, self.performance_ema.pause():
            yield
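
    # note: holding lock_global_progress prevents the fetcher coroutine from overwriting global_progress while
    # the caller aggregates gradients, and pausing the EMA keeps time spent in all-reduce from deflating the
    # measured samples-per-second.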

    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
        """Update the local epoch, reset the number of samples accumulated, reset local progress, return new epoch"""
        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
        if new_epoch is None:
            new_epoch = self.local_progress.epoch + 1
        if new_epoch > self.global_progress.epoch:
            self.global_progress.epoch = new_epoch
            self.global_progress.samples_accumulated = 0
            self.global_progress.eta_next_epoch = float("inf")  # will be re-estimated on the next fetch
        self.report_local_progress(new_epoch, samples_accumulated=0)
        return new_epoch
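
    # The background thread (see run) drives two cooperating coroutines: _progress_reporter publishes this
    # peer's local progress to the DHT, while _progress_fetcher periodically pulls swarm-wide progress.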

    def run(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
        self.shutdown_complete.set()

    async def _progress_reporter(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
        last_report_time = -float("inf")
        store_task = None
        try:
            while not self.shutdown_triggered.is_set():
                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                if self.should_report_progress.is_set():
                    logger.debug("Progress update triggered by report_local_progress")
                    self.should_report_progress.clear()
                else:
                    logger.debug("Progress update triggered by metadata_expiration")

                local_progress = self.local_progress
                last_report_time = get_dht_time()
                store_task = asyncio.create_task(
                    asyncio.wait_for(
                        self.dht.store(
                            key=self.training_progress_key,
                            subkey=self._local_public_key,
                            value=local_progress.dict(),
                            expiration_time=last_report_time + self.metadata_expiration,
                            return_future=True,
                        ),
                        timeout=self.metadata_expiration,
                    )
                )
        finally:
            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
            if store_task is not None and not store_task.done():
                store_task.cancel()

    async def _progress_fetcher(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        loop = asyncio.get_event_loop()
        shutdown_checker = asyncio.create_task(loop.run_in_executor(None, self.shutdown_triggered.wait))

        async def _fetch_progress_unless_shutdown_triggered():
            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
            # wrap the getter in a task: asyncio.wait() requires tasks/futures rather than bare coroutines
            getter = asyncio.create_task(
                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
            )
            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
            if self.shutdown_triggered.is_set():
                return
            return await getter

        try:
            while not self.shutdown_triggered.is_set():
                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
                state_updated_externally = await loop.run_in_executor(
                    None, self.global_state_updated.wait, time_to_next_update
                )
                if state_updated_externally:
                    self.global_state_updated.clear()
                    continue

                async with enter_asynchronously(self.lock_global_progress):
                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
                    if self.shutdown_triggered.is_set():
                        break
                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                    self.global_progress = self._parse_swarm_progress_data(metadata)
        finally:
            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}")

    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        current_time = get_dht_time()

        if not isinstance(metadata, dict) or len(metadata) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second

            return GlobalTrainingProgress(
                self.local_progress.epoch,
                self.local_progress.samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_epoch=current_time + local_eta_next_epoch,
                next_fetch_time=current_time + self.default_refresh_period,
            )
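
        # aggregate progress from every peer entry that passed schema and signature validation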
        valid_peer_entries = [
            LocalTrainingProgress.parse_obj(peer_state.value)
            for peer_state in metadata.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_entries)
        num_clients = sum(peer.client_mode for peer in valid_peer_entries)

        global_epoch = self.local_progress.epoch
        for peer in valid_peer_entries:
            if not peer.client_mode:
                global_epoch = max(global_epoch, peer.epoch)

        total_samples_accumulated = estimated_current_samples = 0
        total_samples_per_second = self.performance_ema.eps

        for peer in valid_peer_entries:
            total_samples_per_second += peer.samples_per_second
            if peer.epoch == global_epoch:
                total_samples_accumulated += peer.samples_accumulated
                estimated_current_samples += (
                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
                )
            # note: we deliberately count samples_accumulated only for peers at the current global epoch, but count
            # performance (samples_per_second) for all peers; the rationale is that out-of-date peers will
            # synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second

        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
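        # e.g. with num_peers=10, expected_drift_peers=3, expected_drift_rate=0.2:
        # expected_max_peers = max(10 + 3, 10 * 1.2) = 13, so the next fetch happens after roughly 10/13 of the
        # estimated time to the next epoch (clipped to [min_refresh_period, max_refresh_period]); refreshing
        # slightly early hedges against newly joined peers advancing the swarm between fetches.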

        logger.log(
            self.status_loglevel,
            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
        )
        return GlobalTrainingProgress(
            global_epoch,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_epoch=current_time + estimated_time_to_next_epoch,
            next_fetch_time=current_time + time_to_next_fetch,
        )

    def shutdown(self, timeout: Optional[float] = None):
        """Permanently disable all tracking activity"""
        self.shutdown_triggered.set()
        self.should_report_progress.set()
        self.global_state_updated.set()
        self.shutdown_complete.wait(timeout)
        # store a tombstone (value=None) so that other peers no longer count this peer towards global progress;
        # the returned future is deliberately not awaited since the DHT may be shutting down as well
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
            return_future=True,
        )