progress_tracker.py
import asyncio
import contextlib
import logging
import threading
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA

logger = get_logger(__name__)

@dataclass(frozen=False)
class GlobalTrainingProgress:
    epoch: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_epoch: float
    next_fetch_time: float


class LocalTrainingProgress(BaseModel):
    peer_id: bytes
    epoch: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]

class ProgressTracker(threading.Thread):
    """
    Auxiliary class that keeps track of local & global training progress, measured in epochs.
    An epoch can be incremented once the collaboration accumulates a specified number of gradients (target_batch_size).
    Similar to a PyTorch LR scheduler, an epoch can be incremented after a single optimizer update or after many
    local updates.

    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
    :param expected_drift_peers: assume that this many new peers can join between epochs
    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between epochs
    :note: the expected collaboration drift parameters are used to adjust how often this tracker refreshes the
      collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds

    Example:

    >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
    >>> local_epoch, local_samples = 0, 0
    >>> while True:
    >>>     accumulate_gradients(batch_size=32)
    >>>     local_samples += 32
    >>>     tracker.report_local_progress(local_epoch, local_samples)
    >>>     if local_epoch < tracker.global_progress.epoch:
    >>>         download_state_from_peers()  # if peer is out of sync, synchronize it with the swarm
    >>>     if tracker.accumulated_enough_samples:
    >>>         with tracker.pause_updates():
    >>>             aggregate_gradients_with_peers()
    >>>             update_model_parameters()
    >>>             local_epoch = tracker.update_epoch(local_epoch + 1)
    >>>             local_samples = 0
    """

    def __init__(
        self,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        *,
        client_mode: Optional[bool] = None,
        min_refresh_period: float = 0.5,
        max_refresh_period: float = 30,
        default_refresh_period: float = 3,
        expected_drift_peers: float = 3,
        expected_drift_rate: float = 0.2,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 30.0,
        status_loglevel: int = logging.DEBUG,
        private_key: Optional[RSAPrivateKey] = None,
        daemon: bool = True,
        start: bool,
    ):
        client_mode = client_mode if client_mode is not None else dht.client_mode
        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
        self.training_progress_key = f"{self.prefix}_progress"
        self.target_batch_size = target_batch_size
        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
        self.default_refresh_period = default_refresh_period
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.status_loglevel = status_loglevel
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.metadata_expiration = metadata_expiration

        signature_validator = RSASignatureValidator(private_key)
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

        # report the collaboration progress periodically or in background
        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        self.global_progress = self._parse_swarm_progress_data(metadata)
        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
        self.should_report_progress = threading.Event()
        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
        if start:
            self.start()

    @property
    def global_epoch(self) -> int:
        return self.global_progress.epoch

    @property
    def ready_to_update_epoch(self) -> bool:
        """Whether or not this peer can increment epoch right away."""
        return (
            self.global_epoch > self.local_progress.epoch
            or self.global_progress.samples_accumulated >= self.target_batch_size
            or get_dht_time() >= self.global_progress.eta_next_epoch
        )

    @property
    def estimated_next_update_time(self) -> DHTExpiration:
        """Estimate (absolute) time when this peer should increment epoch"""
        if self.ready_to_update_epoch:
            return get_dht_time()
        return self.global_progress.eta_next_epoch

    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
        return LocalTrainingProgress(
            peer_id=self.dht.peer_id.to_bytes(),
            epoch=local_epoch,
            samples_accumulated=samples_accumulated,
            samples_per_second=self.performance_ema.samples_per_second,
            time=get_dht_time(),
            client_mode=self.client_mode,
        )

    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
        """Update the number of locally accumulated samples and notify other peers about this."""
        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
        if extra_samples > 0:
            self.performance_ema.update(task_size=extra_samples)
            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
        else:
            logger.debug("Resetting performance timestamp to current time (progress was reset)")
            self.performance_ema.reset_timer()
        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
        self.should_report_progress.set()

    @contextlib.contextmanager
    def pause_updates(self):
        """Temporarily stop progress tracker from updating global training state"""
        with self.lock_global_progress, self.performance_ema.pause():
            yield

    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
        """Update the local epoch, reset the number of samples accumulated, reset local progress, return the new epoch"""
        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
        if new_epoch is None:
            new_epoch = self.local_progress.epoch + 1
        if new_epoch > self.global_progress.epoch:
            self.global_progress.epoch = new_epoch
            self.global_progress.samples_accumulated = 0
            self.global_progress.eta_next_epoch = float("inf")
        self.report_local_progress(new_epoch, samples_accumulated=0)
        return new_epoch

    def run(self):
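        # the tracker runs in its own thread, so it creates a dedicated event loop
        # for the reporter and fetcher coroutines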
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
        self.shutdown_complete.set()

    async def _progress_reporter(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
        last_report_time = -float("inf")
        store_task = None
        try:
            while not self.shutdown_triggered.is_set():
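                # wake up either when local progress changes (should_report_progress is set)
                # or when the previously stored metadata is about to expire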
                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                if self.should_report_progress.is_set():
                    logger.debug("Progress update triggered by report_local_progress")
                    self.should_report_progress.clear()
                else:
                    logger.debug("Progress update triggered by metadata_expiration")

                local_progress = self.local_progress
                last_report_time = get_dht_time()

                store_task = asyncio.create_task(
                    asyncio.wait_for(
                        self.dht.store(
                            key=self.training_progress_key,
                            subkey=self._local_public_key,
                            value=local_progress.dict(),
                            expiration_time=last_report_time + self.metadata_expiration,
                            return_future=True,
                        ),
                        timeout=self.metadata_expiration,
                    )
                )
        finally:
            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
            if store_task is not None and not store_task.done():
                store_task.cancel()

    async def _progress_fetcher(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        loop = asyncio.get_event_loop()
        shutdown_checker = asyncio.create_task(
            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
        )

        async def _fetch_progress_unless_shutdown_triggered():
            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
            # wrap the getter in a task so it can be awaited together with shutdown_checker in asyncio.wait
            getter = asyncio.create_task(
                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
            )
            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
            if self.shutdown_triggered.is_set():
                return
            return await getter

        try:
            while not self.shutdown_triggered.is_set():
                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
                state_updated_externally = await loop.run_in_executor(
                    None, self.global_state_updated.wait, time_to_next_update
                )
                if state_updated_externally:
                    self.global_state_updated.clear()
                    continue

                async with enter_asynchronously(self.lock_global_progress):
                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
                    if self.shutdown_triggered.is_set():
                        break
                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                    self.global_progress = self._parse_swarm_progress_data(metadata)
        finally:
            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")

    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        current_time = get_dht_time()

        if not isinstance(metadata, dict) or len(metadata) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second

            return GlobalTrainingProgress(
                self.local_progress.epoch,
                self.local_progress.samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_epoch=current_time + local_eta_next_epoch,
                next_fetch_time=current_time + self.default_refresh_period,
            )

        valid_peer_entries = [
            LocalTrainingProgress.parse_obj(peer_state.value)
            for peer_state in metadata.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_entries)
        num_clients = sum(peer.client_mode for peer in valid_peer_entries)

        global_epoch = self.local_progress.epoch
        for peer in valid_peer_entries:
            if not peer.client_mode:
                global_epoch = max(global_epoch, peer.epoch)

        total_samples_accumulated = estimated_current_samples = 0
        total_samples_per_second = self.performance_ema.eps

        for peer in valid_peer_entries:
            total_samples_per_second += peer.samples_per_second
            if peer.epoch == global_epoch:
                total_samples_accumulated += peer.samples_accumulated
                estimated_current_samples += (
                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
                )
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second
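
        # shrink the refresh interval in proportion to how many extra peers may join before the next epoch,
        # then clamp it to [min_refresh_period, max_refresh_period]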
        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
        logger.log(
            self.status_loglevel,
            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
        )
        return GlobalTrainingProgress(
            global_epoch,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_epoch=current_time + estimated_time_to_next_epoch,
            next_fetch_time=current_time + time_to_next_fetch,
        )

    def shutdown(self, timeout: Optional[float] = None):
        """Permanently disable all tracking activity"""
        self.shutdown_triggered.set()
        self.should_report_progress.set()
        self.global_state_updated.set()
        self.shutdown_complete.wait(timeout)
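        # explicitly announce that this peer is leaving by overwriting its progress entry with None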
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
            return_future=True,
        )