progress_tracker.py

import asyncio
import contextlib
import logging
import threading
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA

logger = get_logger(__name__)


@dataclass(frozen=False)
class GlobalTrainingProgress:
    epoch: int
    samples_accumulated: int
    target_batch_size: int
    num_peers: int
    num_clients: int
    eta_next_epoch: float
    next_fetch_time: float


class LocalTrainingProgress(BaseModel):
    peer_id: bytes
    epoch: conint(ge=0, strict=True)
    samples_accumulated: conint(ge=0, strict=True)
    samples_per_second: confloat(ge=0.0, strict=True)
    time: StrictFloat
    client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]


class ProgressTracker(threading.Thread):
  34. """
  35. Auxiliary class that keeps track of local & global training progress, measured in epochs.
  36. An epoch can be incremented after collaboration accumulates a said number of gradients (target_batch_size).
  37. Similarly to pytorch LR scheduler, epoch can be incremented on a single optimizer update or many local updates.
  38. :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
  39. :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
  40. :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
  41. :param expected_drift_peers: assume that this many new peers can join between epochs
  42. :param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between epochs
  43. :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
  44. refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
  45. :param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
  46. :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
  47. Example:
  48. >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
  49. >>> local_epoch, local_samples = 0, 0
  50. >>> while True:
  51. >>> accumulate_gradients(batch_size=32)
  52. >>> local_samples += 32
  53. >>> tracker.report_local_progress(local_epoch, local_samples)
  54. >>> if local_epoch < tracker.global_progress.epoch:
  55. >>> download_state_from_peers() # if peer is out of sync, synchronize it with the swarm
  56. >>> if tracker.accumulated_enough_samples:
  57. >>> with tracker.pause_updates():
  58. >>> aggregate_gradients_with_peers()
  59. >>> update_model_parameters()
  60. >>> local_epoch = tracker.update_epoch(local_epoch + 1)
  61. >>> local_samples = 0
  62. """

    def __init__(
        self,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        *,
        client_mode: Optional[bool] = None,
        min_refresh_period: float = 0.5,
        max_refresh_period: float = 10,
        default_refresh_period: float = 3,
        expected_drift_peers: float = 3,
        expected_drift_rate: float = 0.2,
        performance_ema_alpha: float = 0.1,
        metadata_expiration: float = 30.0,
        status_loglevel: int = logging.DEBUG,
        private_key: Optional[RSAPrivateKey] = None,
        daemon: bool = True,
        start: bool,
    ):
        client_mode = client_mode if client_mode is not None else dht.client_mode
        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
        self.training_progress_key = f"{self.prefix}_progress"
        self.target_batch_size = target_batch_size
        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
        self.default_refresh_period = default_refresh_period
        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
        self.status_loglevel = status_loglevel
        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
        self.metadata_expiration = metadata_expiration

        signature_validator = RSASignatureValidator(private_key)
        self._local_public_key = signature_validator.local_public_key
        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])

        # report the collaboration progress periodically or in background
        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
        self.global_progress = self._parse_swarm_progress_data(metadata)
        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
        if start:
            self.start()

    @property
    def global_epoch(self) -> int:
        return self.global_progress.epoch

    @property
    def ready_to_update_epoch(self) -> bool:
        """Whether or not this peer can increment epoch right away."""
        return (
            self.global_epoch > self.local_progress.epoch
            or self.global_progress.samples_accumulated >= self.target_batch_size
            or get_dht_time() >= self.global_progress.eta_next_epoch
        )

    @property
    def estimated_next_update_time(self) -> DHTExpiration:
        """Estimate (absolute) time when this peer should increment epoch"""
        if self.ready_to_update_epoch:
            return get_dht_time()
        return self.global_progress.eta_next_epoch

    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
        return LocalTrainingProgress(
            peer_id=self.dht.peer_id.to_bytes(),
            epoch=local_epoch,
            samples_accumulated=samples_accumulated,
            samples_per_second=self.performance_ema.samples_per_second,
            time=get_dht_time(),
            client_mode=self.client_mode,
        )

    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
        """Update the number of locally accumulated samples and notify other peers about this."""
        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
            self.global_progress.samples_accumulated += extra_samples
            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow

        if extra_samples > 0:
            self.performance_ema.update(task_size=extra_samples)
            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
        else:
            logger.debug("Resetting performance timestamp to current time (progress was reset)")
            self.performance_ema.reset_timer()

        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
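        # wake up the background reporter so the updated progress is published to the DHT right away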
        self.should_report_progress.set()

    @contextlib.contextmanager
    def pause_updates(self):
        """Temporarily stop progress tracker from updating global training state"""
        with self.lock_global_progress, self.performance_ema.pause():
            yield

    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
        """Update the local epoch, reset the number of accumulated samples, reset local progress, return new epoch"""
        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
        if new_epoch is None:
            new_epoch = self.local_progress.epoch + 1
        if new_epoch > self.global_progress.epoch:
            self.global_progress.epoch = new_epoch
            self.global_progress.samples_accumulated = 0
            self.global_progress.eta_next_epoch = float("inf")
        self.report_local_progress(new_epoch, samples_accumulated=0)
        self.fetched_global_progress_this_epoch.clear()
        return new_epoch

    def run(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
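        # run the reporter and fetcher coroutines concurrently in this thread's own event loop until shutdown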
        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
        self.shutdown_complete.set()

    async def _progress_reporter(self):
        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
        last_report_time = -float("inf")
        store_task = None
        try:
            while not self.shutdown_triggered.is_set():
                wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
                if self.should_report_progress.is_set():
                    logger.debug("Progress update triggered by report_local_progress")
                    self.should_report_progress.clear()
                else:
                    logger.debug("Progress update triggered by metadata_expiration")

                local_progress = self.local_progress
                last_report_time = get_dht_time()
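                # store the report without blocking this loop; the write is bounded by metadata_expiration,
                # so a stalled DHT request cannot wedge the reporter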
                store_task = asyncio.create_task(
                    asyncio.wait_for(
                        self.dht.store(
                            key=self.training_progress_key,
                            subkey=self._local_public_key,
                            value=local_progress.dict(),
                            expiration_time=last_report_time + self.metadata_expiration,
                            return_future=True,
                        ),
                        timeout=self.metadata_expiration,
                    )
                )
        finally:
            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
            if store_task is not None:
                store_task.cancel()

    async def _progress_fetcher(self):
        """
        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
        """
        loop = asyncio.get_event_loop()
        shutdown_checker = asyncio.create_task(
            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
        )
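        # shutdown_checker completes once shutdown_triggered is set; it is raced against every DHT get below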

        async def _fetch_progress_unless_shutdown_triggered():
            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
            getter = asyncio.create_task(
                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
            )
            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
            if self.shutdown_triggered.is_set():
                return
            return await getter

        try:
            while not self.shutdown_triggered.is_set():
                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
                state_updated_externally = await loop.run_in_executor(
                    None, self.global_state_updated.wait, time_to_next_update
                )
                if state_updated_externally:
                    self.global_state_updated.clear()
                    continue

                async with enter_asynchronously(self.lock_global_progress):
                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
                    if self.shutdown_triggered.is_set():
                        break
                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                    self.global_progress = self._parse_swarm_progress_data(metadata)
                    self.fetched_global_progress_this_epoch.set()
        finally:
            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")

    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
        """Read performance statistics reported by peers, estimate progress towards next batch"""
        current_time = get_dht_time()

        if not isinstance(metadata, dict) or len(metadata) == 0:
            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
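            # no metadata from other peers: fall back to this peer's own progress and measured throughput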
            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second

            return GlobalTrainingProgress(
                self.local_progress.epoch,
                self.local_progress.samples_accumulated,
                self.target_batch_size,
                num_peers=0,
                num_clients=0,
                eta_next_epoch=current_time + local_eta_next_epoch,
                next_fetch_time=current_time + self.default_refresh_period,
            )

        valid_peer_entries = [
            LocalTrainingProgress.parse_obj(peer_state.value)
            for peer_state in metadata.values()
            if peer_state.value is not None
        ]

        num_peers = len(valid_peer_entries)
        num_clients = sum(peer.client_mode for peer in valid_peer_entries)

        global_epoch = self.local_progress.epoch
        for peer in valid_peer_entries:
            if not peer.client_mode:
                global_epoch = max(global_epoch, peer.epoch)

        total_samples_accumulated = estimated_current_samples = 0
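        # seed the throughput sum with a tiny eps so the ETA division below never divides by zero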
        total_samples_per_second = self.performance_ema.eps

        for peer in valid_peer_entries:
            total_samples_per_second += peer.samples_per_second
            if peer.epoch == global_epoch:
                total_samples_accumulated += peer.samples_accumulated
                estimated_current_samples += (
                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
                )
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second

        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
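        # shrink the refresh interval in proportion to expected peer churn: the more peers may join before the next
        # epoch, the sooner we re-read the swarm state so we do not miss the epoch transition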
        time_to_next_fetch = float(
            np.clip(
                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
                a_min=self.min_refresh_period,
                a_max=self.max_refresh_period,
            )
        )
        logger.log(
            self.status_loglevel,
            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
        )
        return GlobalTrainingProgress(
            global_epoch,
            total_samples_accumulated,
            target_batch_size=self.target_batch_size,
            num_peers=num_peers,
            num_clients=num_clients,
            eta_next_epoch=current_time + estimated_time_to_next_epoch,
            next_fetch_time=current_time + time_to_next_fetch,
        )

    def shutdown(self, timeout: Optional[float] = None):
        """Permanently disable all tracking activity"""
        self.shutdown_triggered.set()
        self.should_report_progress.set()
        self.global_state_updated.set()
        self.shutdown_complete.wait(timeout)
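        # publish a None entry under our public key as a tombstone so other peers stop counting this peer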
        self.dht.store(
            self.training_progress_key,
            subkey=self._local_public_key,
            value=None,
            expiration_time=get_dht_time() + self.metadata_expiration,
            return_future=True,
        )
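

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it assumes a standalone hivemind DHT can be
    # started locally, and the prefix and target_batch_size below are arbitrary example values.
    dht = DHT(start=True)
    tracker = ProgressTracker(dht, prefix="demo_run", target_batch_size=256, start=True)
    tracker.report_local_progress(local_epoch=0, samples_accumulated=32)
    print(f"global epoch: {tracker.global_epoch}, ready to update: {tracker.ready_to_update_epoch}")
    tracker.shutdown()
    dht.shutdown()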