from __future__ import annotations

import logging
import os
import time
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch

from hivemind.averaging.control import AveragingStage, StepControl
from hivemind.compression import CompressionBase, NoCompression
from hivemind.dht import DHT
from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
from hivemind.optim.grad_scaler import GradScaler
from hivemind.optim.power_ef_averager import PowerEFGradientAverager
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
from hivemind.optim.state_averager import (
    LRSchedulerBase,
    OptimizerFactory,
    Parameters,
    ParamGroups,
    SchedulerFactory,
    TorchOptimizer,
    TrainingStateAverager,
)
from hivemind.utils import PerformanceEMA, get_dht_time, get_logger

logger = get_logger(__name__)


class Optimizer(torch.optim.Optimizer):
    """
    hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.

    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
    There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
    or even fully asynchronous (grad_averager=None).

    :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:

    >>> model = transformers.AutoModel("albert-xxlarge-v2")
    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
    >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
    >>>                          params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
    >>> while True:
    >>>     loss = compute_loss_on_batch(model, batch_size=4)
    >>>     opt.zero_grad()
    >>>     loss.backward()
    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)

    By default, peers will perform the following steps:

    * accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
    * after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;
    * if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;

    Unlike regular training, your device may join midway through training, when other peers already made some progress.
    For this reason, any learning rate schedulers, curriculum and other **time-dependent features should be based on**
    ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
    may end up having different learning rates. To do so automatically, specify the ``scheduler=...`` parameter below.
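
    :example: (illustrative sketch) any epoch-dependent logic should key off the shared epoch counter; here,
      ``WARMUP_EPOCHS`` and ``enable_long_sequences`` are hypothetical placeholders for your own curriculum code:

    >>> if opt.local_epoch >= WARMUP_EPOCHS:   # hypothetical threshold, not part of hivemind
    >>>     enable_long_sequences(dataloader)  # hypothetical helper from user code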

    :What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
      corresponds to processing a certain number of training samples (``target_batch_size``) in total across all peers.
      Like in PyTorch LR Scheduler, **epoch does not necessarily correspond to a full pass over the training data.**
      At the end of an epoch, peers perform synchronous actions such as averaging gradients for a global optimizer
      update, updating the learning rate scheduler or simply averaging parameters (if using local updates).
      The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters.
      For instance, if the number of peers doubles, they will run all-reduce more frequently to adjust for faster training.

    :Configuration guide: This guide will help you set up your first collaborative training run. It covers the most
      important basic options, but ignores features that require significant changes to the training code.

    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
    >>> opt = hivemind.Optimizer(
    >>>    dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
    >>>    # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
    >>>    # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
    >>>
    >>>    params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
    >>>    # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
    >>>    scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
    >>>    # scheduler.step will be called automatically each time peers collectively accumulate target_batch_size
    >>>
    >>>    offload_optimizer=True,  # saves GPU memory, but increases RAM usage; generally a good practice to use this.
    >>>    delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL,  # train faster, but with 1 round of staleness;
    >>>    # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
    >>>
    >>>    grad_compression=hivemind.Float16Compression(), state_averaging_compression=hivemind.Float16Compression(),
    >>>    # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precaution;
    >>>    # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
    >>>
    >>>    matchmaking_time=15.0,  # 3-5s for small local runs, 10-15s for training over the internet or with many peers
    >>>    averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
    >>>    verbose=True  # periodically report the training progress to the console (e.g. "Averaged with N peers")
    >>> )  # and you're done!

    :param dht: a running hivemind.DHT instance connected to other peers.
    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.

    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch.
      The actual batch may be *slightly* larger due to asynchrony (e.g. peers submit more gradients in the last second).
    :param batch_size_per_step: you should accumulate gradients over this many samples between calls to optimizer.step.

    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params).
    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer.
      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging require
      the callable form and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.
    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
      The learning rate scheduler will adjust learning rate based on the global epoch, not the number of
      local calls to optimizer.step; this is required to keep different peers synchronized.

    :param matchmaking_time: when looking for a group, wait for peers to join for up to this many seconds.
      Increase if you see "averaged gradients with N peers" where N is below 0.9x the real size on >=25% of epochs.
      When training with a low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.
    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
      Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
      Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.

    :param reuse_grad_buffers: if True, use the model's .grad buffers for gradient accumulation.
      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step

    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
      This reduces communication overhead but can cause parameters to diverge if set too large.
      The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
      hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
      gradient compression and local_updates cause parameters to diverge faster and require more frequent averaging.

    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)

    :param grad_compression: compression strategy used for averaging gradients, default = no compression
    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
    :param load_state_compression: compression strategy for loading state from peers, default = no compression
    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.

    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
    :param verbose: if True, report internal events such as accumulating gradients and running background tasks

    :note: in a large-scale training run, peers will inevitably fail and you will see error messages. hivemind.Optimizer
      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
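
    :example: (illustrative sketch) with ``reuse_grad_buffers=True``, gradient buffers are managed internally and the
      training loop must not call ``zero_grad`` at all; everything except the arguments shown is omitted here:

    >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
    >>>                          params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
    >>>                          reuse_grad_buffers=True)
    >>> for batch in dataloader:
    >>>     loss = compute_loss_on_batch(model, batch)  # hypothetical helper, as in the example above
    >>>     loss.backward()
    >>>     opt.step()  # note: no opt.zero_grad() -- with reuse_grad_buffers=True it would raise a ValueError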
    """

    def __init__(
        self,
        *,
        dht: DHT,
        run_id: str,
        target_batch_size: int,
        batch_size_per_step: Optional[int] = None,
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        params: Optional[Union[Parameters, ParamGroups]] = None,
        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
        matchmaking_time: Optional[float] = 15.0,
        averaging_timeout: Optional[float] = 60.0,
        allreduce_timeout: Optional[float] = None,
        next_chunk_timeout: Optional[float] = None,
        load_state_timeout: float = 600.0,
        reuse_grad_buffers: bool = False,
        offload_optimizer: Optional[bool] = None,
        delay_optimizer_step: Optional[bool] = None,
        delay_grad_averaging: bool = False,
        delay_state_averaging: bool = True,
        average_state_every: int = 1,
        client_mode: bool = None,
        auxiliary: bool = False,
        grad_compression: CompressionBase = NoCompression(),
        grad_averager: Optional[GradientAveragerFactory] = PowerEFGradientAverager.get_factory(averager_rank=32),
        use_ext_grad_buffer: bool = True,
        state_averaging_compression: CompressionBase = NoCompression(),
        load_state_compression: CompressionBase = NoCompression(),
        average_opt_statistics: Sequence[str] = (),
        extra_tensors: Sequence[torch.Tensor] = (),
        averager_opts: Optional[dict] = None,
        tracker_opts: Optional[dict] = None,
        performance_ema_alpha: float = 0.1,
        shutdown_timeout: float = 5,
        verbose: bool = False,
    ):
        self._parent_pid = os.getpid()
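
        # Resolve interdependent defaults below: any value that was not set explicitly falls back to a related
        # setting (e.g. allreduce_timeout defaults to averaging_timeout, and delayed optimizer steps are implied
        # by delayed gradient averaging).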
        client_mode = client_mode if client_mode is not None else dht.client_mode
        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
        next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"

        if callable(optimizer) and params is not None:
            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
                raise ValueError(
                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params))"
                )
        else:
            raise ValueError(
                "Please initialize the optimizer in one of the following two ways:\n"
                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params))\n"
                "(B) hivemind.Optimizer(..., optimizer=pre_initialized_optimizer)"
            )
        if grad_averager is None:
            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"

        params = list(params) if params is not None else optimizer.param_groups
        if all(isinstance(p, torch.Tensor) for p in params):
            params = (dict(params=params),)

        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
        self.next_chunk_timeout = next_chunk_timeout
        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
        self.scheduled_grads: Optional[StepControl] = None
        self.scheduled_state: Optional[StepControl] = None

        self.tracker = self._make_progress_tracker(
            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
        )

        averaged_grads = None
        if use_ext_grad_buffer:
            assert grad_averager is not None, "Use external gradient buffers only with a working gradient averager."
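            # Shared buffers for the external gradient averager: for every parameter, a (flattened_dim x 32) matrix
            # (presumably the rank-32 factor used by the PowerEF averager, matching the default averager_rank=32
            # above) plus one full gradient-shaped buffer; share_memory_() makes them accessible to the background
            # averager process.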
            averaged_grads = [
                torch.rand((param.reshape((param.size(0), -1)).size(1), 32), device="cpu").share_memory_()
                for param_group in params
                for param in param_group["params"]
            ] + [
                torch.zeros_like(param).share_memory_()
                for param_group in params
                for param in param_group["params"]
            ]
            extra_tensors = [e for e in extra_tensors] + averaged_grads

        self.state_averager = self._make_state_averager(
            optimizer=optimizer,
            params=params,
            scheduler=scheduler,
            delta_rule_averaging=grad_averager is None and self.delay_state_averaging,
            compression=state_averaging_compression,
            state_compression=load_state_compression,
            average_opt_statistics=average_opt_statistics,
            performance_ema_alpha=performance_ema_alpha,
            extra_tensors=extra_tensors,
            **averager_opts or {},
        )
        if grad_averager:
            self.grad_averager = self._make_gradient_averager(
                reuse_grad_buffers=reuse_grad_buffers, grad_averager=grad_averager, averaged_grads=averaged_grads
            )
        else:
            self.grad_averager = None

        self._should_check_synchronization_on_update = True  # used in self._should_load_state_from_peers
        self._schema_hash = self._compute_schema_hash()

        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
        # used for pre-scheduling the averaging round in state_averager

        self._step_supports_amp_scaling = reuse_grad_buffers
        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.

    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
        return TrainingStateAverager(
            dht=self.dht,
            prefix=f"{self.run_id}_state_averager",
            min_matchmaking_time=self.matchmaking_time,
            allreduce_timeout=self.allreduce_timeout,
            shutdown_timeout=self.shutdown_timeout,
            offload_optimizer=self.offload_optimizer,
            custom_gradients=self.offload_optimizer,
            status_loglevel=self.status_loglevel,
            next_chunk_timeout=self.next_chunk_timeout,
            client_mode=self.client_mode,
            auxiliary=self.auxiliary,
            start=True,
            **kwargs,
        )

    def _make_gradient_averager(self, grad_averager, **kwargs) -> GradientAverager:
        assert hasattr(self, "state_averager"), "must initialize state averager first"
        grad_averager = grad_averager(
            dht=self.dht,
            prefix=f"{self.run_id}_grad_averager",
            parameters=self.state_averager.main_parameters,
            min_matchmaking_time=self.matchmaking_time,
            allreduce_timeout=self.allreduce_timeout,
            shutdown_timeout=self.shutdown_timeout,
            next_chunk_timeout=self.next_chunk_timeout,
            client_mode=self.client_mode,
            auxiliary=self.auxiliary,
            start=True,
            **kwargs,
        )
        if self.offload_optimizer:
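            # When the optimizer is offloaded, install the averager's gradient tensors directly as the .grad
            # attributes of the offloaded optimizer's parameters, so averaged gradients are consumed in place
            # without an extra copy (see _load_averaged_gradients_into_optimizer_ below).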
            optimized_param_groups = self.state_averager.optimizer.param_groups
            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
            with grad_averager.get_tensors() as averaged_gradients:
                assert len(averaged_gradients) == len(optimized_parameters)
                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
                    opt_param.grad = averaged_grad
        return grad_averager

    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
        return ProgressTracker(
            dht=self.dht,
            prefix=self.run_id,
            target_batch_size=target_batch_size,
            client_mode=self.client_mode,
            status_loglevel=self.status_loglevel,
            start=True,
            **kwargs,
        )

    def _compute_schema_hash(self) -> int:
        optimized_param_groups = self.state_averager.optimizer.param_groups
        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
        # offloaded optimizer requires that gradient tensors are reused between iterations
        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
        return hash((grad_ids, param_shapes))

    def is_alive(self) -> bool:
        return self.state_averager.is_alive()

    @property
    def local_epoch(self) -> int:
        """
        This worker's current epoch, kept synchronized with peers. If a peer's local_epoch lags behind others, it will
        automatically re-synchronize by downloading state from another peer.
        An epoch corresponds to accumulating target_batch_size across all active devices.
        """
        return self.state_averager.local_epoch

    @property
    def local_progress(self) -> LocalTrainingProgress:
        return self.tracker.local_progress

    @property
    def use_local_updates(self) -> bool:
        return self.grad_averager is None

    @property
    def use_gradient_averaging(self) -> bool:
        return self.grad_averager is not None

    def step(
        self,
        closure: Optional[Callable[[], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        grad_scaler: Optional[GradScaler] = None,
    ):
        """
        Update training progress after accumulating another local batch of samples. Depending on the configuration,
        this will report progress to peers, run a global or local optimizer step, average parameters or schedule
        background tasks.

        :param closure: A closure that reevaluates the model and returns the loss.
        :param batch_size: optional override for batch_size_per_step from init.
        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler.
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        # if delayed updates finished before step, apply these updates; otherwise do nothing
        self.state_averager.step(apply_delayed_updates=True)

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.auxiliary and self._should_load_state_from_peers():
            logger.log(self.status_loglevel, "Peer is out of sync")
            self.load_state_from_peers()
            return loss  # local gradients were computed with out-of-sync parameters, must start over

        if self.use_gradient_averaging:
            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
            if not self.auxiliary:
                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
                if not grads_are_valid:
                    return loss  # local gradients were reset due to overflow, must start over
            self._maybe_schedule_gradient_averaging()
            self._maybe_schedule_state_averaging()
        else:
            # grad_averager=None: update parameters on every step independently of other peers
            if not self.auxiliary:
                if grad_scaler is not None:
                    with grad_scaler.running_global_step():
                        assert grad_scaler.unscale_(self)

                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
                self._maybe_schedule_state_averaging()

                self.state_averager.step(
                    increment_epoch=False,
                    optimizer_step=True,
                    delay_optimizer_step=self.delay_optimizer_step,
                    grad_scaler=grad_scaler,
                )

        if self.tracker.ready_to_update_epoch:
            self._update_global_epoch(grad_scaler)

        return loss

    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
        _epoch_start_time = time.perf_counter()

        with self.tracker.pause_updates():
            wait_for_trigger = None
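            # wait_for_trigger is forwarded to state_averager.step below; when delayed gradient averaging is
            # enabled, it makes the background optimizer step wait until gradients have been averaged and loaded.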
            if self.use_gradient_averaging:
                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
                if self.delay_optimizer_step:
                    self.state_averager.step(wait_for_delayed_updates=True)

                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                if not began_averaging_gradients:
                    # failed to start gradient averaging due to an internal error
                    self.grad_averager.load_accumulators_into_averager_()
                elif self.delay_grad_averaging:
                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
                else:
                    # delay_grad_averaging=False, average gradients immediately
                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)

            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
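            # ^-- if the swarm has already advanced past this peer, jump straight to the global epoch instead of +1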
            swarm_not_empty = self.tracker.global_progress.num_peers > 1
            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
            should_average_state = (
                swarm_not_empty
                and next_epoch % self.average_state_every == 0
                and not self.state_averager.averaging_in_progress
            )

            if should_average_state and self.scheduled_state is not None:
                if self.scheduled_state.triggered or self.scheduled_state.done():
                    logger.log(
                        self.status_loglevel,
                        f"Not using pre-scheduled group for state averaging because it "
                        f"was already used elsewhere: {self.scheduled_state}",
                    )
                    self.scheduled_state = None
                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)

            self.state_averager.step(
                increment_epoch=True,
                wait_for_trigger=wait_for_trigger,
                optimizer_step=should_perform_optimizer_step,
                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
                grad_scaler=grad_scaler,
                averaging_round=should_average_state,
                delay_averaging=self.delay_state_averaging and not self.auxiliary,
                averaging_control=self.scheduled_state if should_average_state else None,
                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
            )

            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
                self.scheduled_state.cancel()
                self.scheduled_state = None

            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
            self._should_check_synchronization_on_update = True
            # the above line ensures that peers check for *strict* synchronization once per epoch

            if not self.client_mode:
                self.state_averager.state_sharing_priority = self.local_epoch

            if self.use_gradient_averaging and not self.auxiliary:
                self.grad_averager.reset_accumulated_grads_()
                if not self.client_mode:
                    self.grad_averager.state_sharing_priority = self.local_epoch

            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}")
    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
        if grad_scaler is not None:
            with grad_scaler.running_global_step():
                assert grad_scaler.unscale_(self)

        began_averaging_gradients = False
        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
            logger.log(
                self.status_loglevel,
                f"Not using pre-scheduled group for gradient averaging because it "
                f"was already used elsewhere: {self.scheduled_grads}",
            )
            self.scheduled_grads = None
        elif self.tracker.global_progress.num_peers > 1:
            try:
                self.scheduled_grads = self.grad_averager.step(
                    control=self.scheduled_grads, reset_accumulators=True, wait=False
                )
                began_averaging_gradients = True
            except BaseException as e:
                logger.exception(e)

        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
            if self.tracker.global_progress.num_peers > 1:
                logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
                self._tag_along_with_zero_weight(self.scheduled_grads)
            else:
                logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
                self._load_local_gradients_into_optimizer()
                self.scheduled_grads.cancel()
            self.scheduled_grads = None
        return began_averaging_gradients

    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
        assert not self.use_local_updates and not self.auxiliary
        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
            self.grad_averager.reset_accumulated_grads_()
            return False

        self.grad_averager.accumulate_grads_(batch_size)
        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
        return True

    def _maybe_schedule_gradient_averaging(self) -> None:
        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
        assert self.use_gradient_averaging
        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)

    def _maybe_schedule_state_averaging(self) -> None:
        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
        if next_epoch % self.average_state_every != 0:
            return  # averaging is not performed at this epoch
        if self.state_averager.averaging_in_progress:
            return  # previous run is still in progress
        if self.delay_before_state_averaging.num_updates == 0:
            return  # not enough data to accurately pre-schedule

        estimated_time = self.tracker.estimated_next_update_time
        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
        eta_seconds_to_averaging = estimated_time - get_dht_time()

        if eta_seconds_to_averaging <= self.matchmaking_time:
            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
                self.scheduled_state = self.state_averager.schedule_step(
                    gather=next_epoch, timeout=self.averaging_timeout
                )

    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
        averaged_gradients = False
        try:
            if maybe_step_control is not None:
                group_info = maybe_step_control.result(self.averaging_timeout)
                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
                self._load_averaged_gradients_into_optimizer_()
                averaged_gradients = True
            else:
                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
        except BaseException as e:
            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")

        if not averaged_gradients:
            self._load_local_gradients_into_optimizer()

    def _load_averaged_gradients_into_optimizer_(self):
        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
        assert self.use_gradient_averaging

        if self.offload_optimizer:
            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
        else:
            # copy averaged gradients into optimizer .grad buffers
            optimized_param_groups = self.state_averager.optimizer.param_groups
            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
                assert len(averaged_gradients) == len(optimized_parameters)
                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
                    opt_param.grad.copy_(averaged_grad, non_blocking=True)

        self.grad_averager.notify_used_averaged_gradients()

    def _load_local_gradients_into_optimizer(self):
        """Fallback to using local gradients in the optimizer (instead of averaged gradients)"""
        logger.log(self.status_loglevel, f"Proceeding with local gradients")
        self.grad_averager.load_accumulators_into_averager_()
        # note: we load gradients into grad_averager even though there is only one peer because of two reasons:
        # - if offload_optimizer, then we must load gradients onto the CPU gradient buffers used by the optimizer
        # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
        self._load_averaged_gradients_into_optimizer_()

    def zero_grad(self, set_to_none: bool = False):
        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
            raise ValueError(
                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                f"call zero_grad manually. Gradients will be refreshed internally"
            )
        for param_group in self.param_groups:
            for param in param_group["params"]:
                if param.grad is None:
                    pass
                elif set_to_none:
                    param.grad = None
                else:
                    param.grad.zero_()

    def _should_load_state_from_peers(self) -> bool:
        """
        If true, peer will discard local progress and attempt to download state from peers.
        This method allows peer to continue training in two cases:
         - peer is on the same epoch as other collaborators - keep training normally
         - peer was on the same epoch and accumulated some grads, but some collaborators
           have just transitioned to the next epoch - this peer should also transition.

        :note: The latter case occurs due to the lack of network synchrony: the first peer that
          detects enough samples will transition to the next step and start counting samples anew.
          Some other peers may take time before they check with DHT and observe that
           - the global epoch is technically one epoch ahead of the current one and
           - the remaining (non-transitioned) peers no longer have target_batch_size between them
          If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
        """
        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
            self._should_check_synchronization_on_update = False
            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch

    def is_synchronized_with_peers(self) -> bool:
        """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number."""
        return self.local_epoch >= self.tracker.global_epoch - 1

    def load_state_from_peers(self, **kwargs):
        """
        Attempt to load the newest collaboration state from other peers within the same run_id.

        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
        """
        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
        if self.scheduled_grads is not None and not self.scheduled_grads.done():
            self._tag_along_with_zero_weight(self.scheduled_grads)
            self.scheduled_grads = None
        self.state_averager.step(wait_for_delayed_updates=True)

        with self.tracker.pause_updates():
            while True:
                try:
                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                    break
                except KeyboardInterrupt:
                    raise
                except BaseException as e:
                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                    continue

            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
                self.state_averager.local_epoch = self.tracker.global_epoch

            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)

            if not self.client_mode:
                self.state_averager.state_sharing_priority = self.local_epoch

            if self.use_gradient_averaging:
                self.grad_averager.reset_accumulated_grads_()
                if not self.client_mode:
                    self.grad_averager.state_sharing_priority = self.local_epoch

    def state_dict(self) -> dict:
        state_dict = self.state_averager.optimizer.state_dict()
        state_dict["state"]["local_epoch"] = self.local_epoch
        return state_dict

    def load_state_dict(self, state_dict: dict):
        if "local_epoch" in state_dict["state"]:
            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
        return self.state_averager.optimizer.load_state_dict(state_dict)

    @property
    def state(self):
        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)

    @property
    def opt(self) -> TorchOptimizer:
        return self.state_averager.optimizer

    @property
    def param_groups(self) -> ParamGroups:
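        # Re-map each optimizer param group onto the user-facing main parameters held by the state averager,
        # so that code iterating over self.param_groups (e.g. zero_grad above) touches the model's own tensors
        # rather than the offloaded optimizer's internal copies.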
        next_index = 0
        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
        for param_group in param_groups:
            num_params = len(param_group["params"])
            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
            param_group["params"] = main_params_for_group
            next_index += num_params
        assert next_index == len(self.state_averager.main_parameters)
        return param_groups

    def add_param_group(self, param_group: dict) -> None:
        raise ValueError(
            f"{self.__class__.__name__} does not support calling add_param_group after creation. "
            f"Please provide all parameter groups at init"
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"

    def _tag_along_with_zero_weight(self, control: StepControl):
        """Wait for a running averaging round to finish with zero weight."""
        if not control.triggered:
            control.weight = 0
            control.allow_allreduce()
        if not control.done():
            try:
                control.result(self.averaging_timeout)
            except BaseException as e:
                logger.exception(e)
                if not control.done():
                    control.cancel()

    def shutdown(self):
        logger.log(self.status_loglevel, "Sending goodbye to peers...")
        self.tracker.shutdown(self.shutdown_timeout)
        self.state_averager.step(wait_for_delayed_updates=True)
        for scheduled_round in self.scheduled_grads, self.scheduled_state:
            if scheduled_round is not None:
                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
                    scheduled_round.cancel()
                else:
                    self._tag_along_with_zero_weight(scheduled_round)

        logger.log(self.status_loglevel, "Shutting down averagers...")
        self.state_averager.shutdown()
        if self.use_gradient_averaging:
            self.grad_averager.shutdown()
        logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down")

    def __del__(self):
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()