optimizer.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import time
  5. from functools import partial
  6. from typing import Callable, Optional, Sequence, Union
  7. import torch
  8. from hivemind.averaging.control import AveragingStage, StepControl
  9. from hivemind.compression import CompressionBase, NoCompression
  10. from hivemind.dht import DHT
  11. from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
  12. from hivemind.optim.grad_scaler import GradScaler
  13. from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
  14. from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
  15. from hivemind.optim.state_averager import (
  16. LRSchedulerBase,
  17. OptimizerFactory,
  18. Parameters,
  19. ParamGroups,
  20. SchedulerFactory,
  21. TorchOptimizer,
  22. TrainingStateAverager,
  23. )
  24. from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
  # module-wide logger used by the Optimizer methods in this file for status and error reporting
  logger = get_logger(__name__)
  26. class Optimizer(torch.optim.Optimizer):
  27. """
  28. hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
  29. By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
  30. There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
  31. or even fully asynchronous (use_local_updates=True).
  32. :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:
  33. >>> model = transformers.AutoModel("albert-xxlarge-v2")
  34. >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
  35. >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
  36. >>> params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
  37. >>> while True:
  38. >>> loss = compute_loss_on_batch(model, batch_size=4)
  39. >>> opt.zero_grad()
  40. >>> loss.backward()
  41. >>> opt.step() # <-- train collaboratively with any peers that use the same prefix (run_42)
  42. By default, peers will perform the following steps:
  43. * accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
  44. * after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;
  45. * if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;
  46. Unlike regular training, your device may join midway through training, when other peers already made some progress.
  47. For this reason, any learning rate schedulers, curriculum and other **time-dependent features should be based on**
  48. ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
  49. may end up having different learning rates. To do so automatically, specify ``scheduler=...`` parameter below.
  50. :What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
  51. corresponds to processing a certain number of training samples (``target_batch_size``) in total across all peers.
  52. Like in PyTorch LR Scheduler, **epoch does not necessarily correspond to a full pass over the training data.**
  53. At the end of epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update,
  54. updating the learning rate scheduler or simply averaging parameters (if using local updates).
  55. The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters.
  56. For instance, if the number of peers doubles, they will run all-reduce more frequently to adjust for faster training.
  57. :Configuration guide: This guide will help you set up your first collaborative training run. It covers the most
  58. important basic options, but ignores features that require significant changes to the training code.
  59. >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
  60. >>> opt = hivemind.Optimizer(
  61. >>> dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
  62. >>> batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
  63. >>> # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
  64. >>> # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
  65. >>>
  66. >>> params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
  67. >>> # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
  68. >>> scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
  69. >>> # scheduler.step will be called automatically each time when peers collectively accumulate target_batch_size
  70. >>>
  71. >>> offload_optimizer=True, # saves GPU memory, but increases RAM usage; Generally a good practice to use this.
  72. >>> delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL, # train faster, but with 1 round of staleness;
  73. >>> # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
  74. >>>
  75. >>> grad_compression=hivemind.Float16Compression(), state_averaging_compression=hivemind.Float16Compression(),
  76. >>> # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precaution;
  77. >>> # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
  78. >>>
  79. >>> matchmaking_time=15.0, # 3-5s for small local runs, 10-15s for training over the internet or with many peers
  80. >>> averaging_timeout=60.0, # around 2x the actual time it takes to run all-reduce
  81. >>> verbose=True # periodically report the training progress to the console (e.g. "Averaged with N peers")
  82. >>> ) # and you're done!
  83. :param dht: a running hivemind.DHT instance connected to other peers.
  84. :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
  85. **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
  86. Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
  87. ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
  88. individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
  89. :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch.
  90. The actual batch may be *slightly* larger due to asynchrony (e.g. peers submit more gradients in the last second).
  91. :param batch_size_per_step: you should accumulate gradients over this many samples between calls to optimizer.step.
  92. :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params).
  93. :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer.
  94. **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging require
  95. the callable and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.
  96. :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
  97. The learning rate scheduler will adjust learning rate based on global epoch, not the number of
  98. local calls to optimizer.step; this is required to keep different peers synchronized.
  99. :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds.
  100. Increase if you see "averaged gradients with N peers" where N is below 0.9x the real size on >=25% of epochs.
  101. When training with low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.
  102. :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
  103. Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
  104. Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
  105. :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
  106. :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
  107. :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
  108. This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
  109. :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
  110. :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
  111. :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
  112. :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
  113. if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
  114. :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
  115. This reduces the communication overhead increasing, but can cause parameters to diverge if too large.
  116. The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
  117. hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
  118. gradient compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
  119. :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
  120. if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
  121. Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
  122. :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
  123. :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
  124. :param grad_compression: compression strategy used for averaging gradients, default = no compression
  125. :param grad_averager_factory: if provided, creates gradient averager with required averaging strategy
  126. :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
  127. :param load_state_compression: compression strategy for loading state from peers, default = no compression
  128. :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
  129. :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
  130. :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
  131. :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
  132. :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
  133. :param verbose: if True, report internal events such as accumulating gradients and running background tasks
  134. :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
  135. is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
  136. """
  def __init__(
      self,
      *,
      dht: DHT,
      run_id: str,
      target_batch_size: int,
      batch_size_per_step: Optional[int] = None,
      optimizer: Union[TorchOptimizer, OptimizerFactory],
      params: Optional[Union[Parameters, ParamGroups]] = None,
      scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
      matchmaking_time: Optional[float] = 15.0,
      averaging_timeout: Optional[float] = 60.0,
      allreduce_timeout: Optional[float] = None,
      next_chunk_timeout: Optional[float] = None,
      load_state_timeout: float = 600.0,
      reuse_grad_buffers: bool = False,
      offload_optimizer: Optional[bool] = None,
      delay_optimizer_step: Optional[bool] = None,
      delay_grad_averaging: bool = False,
      delay_state_averaging: bool = True,
      average_state_every: int = 1,
      use_local_updates: bool = False,
      client_mode: Optional[bool] = None,
      auxiliary: bool = False,
      grad_compression: CompressionBase = NoCompression(),
      grad_averager_factory: Optional[GradientAveragerFactory] = GradientAverager,
      state_averaging_compression: CompressionBase = NoCompression(),
      load_state_compression: CompressionBase = NoCompression(),
      average_opt_statistics: Sequence[str] = (),
      extra_tensors: Sequence[torch.Tensor] = (),
      averager_opts: Optional[dict] = None,
      tracker_opts: Optional[dict] = None,
      performance_ema_alpha: float = 0.1,
      shutdown_timeout: float = 5,
      verbose: bool = False,
  ):
      """
      Validate the configuration, then create the progress tracker, the state averager and (unless local updates
      are used or grad_averager_factory is None) the gradient averager. Parameters are described in the class docstring.

      :param client_mode: if None (default), inherit ``dht.client_mode``; otherwise use the given value.
      :param next_chunk_timeout: forwarded to both averagers; if None, defaults to ``matchmaking_time``.
      :param shutdown_timeout: forwarded to both averagers as their ``shutdown_timeout``.
      :param grad_compression: forwarded to the gradient averager as ``compression``.
      :raises ValueError: if ``optimizer``, ``params`` and ``scheduler`` are combined in an unsupported way.
      :raises AssertionError: on mutually incompatible options, e.g. ``delay_grad_averaging`` without
          ``delay_optimizer_step``, ``client_mode`` together with ``auxiliary``, or a custom
          ``grad_averager_factory`` together with ``use_local_updates``.
      """
      self._parent_pid = os.getpid()

      # inherit the DHT's mode only when the caller did not choose client_mode explicitly
      client_mode = client_mode if client_mode is not None else dht.client_mode
      delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
      if offload_optimizer is None:
          offload_optimizer = params is not None and not use_local_updates

      allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
      next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
      assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
      assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
      assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"

      if callable(optimizer) and params is not None:
          # factory mode: the optimizer is created later, so the scheduler must be a factory as well
          if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
              raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
      elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
          # pre-initialized optimizer: per the class docstring, these options require an optimizer factory
          if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
              raise ValueError(
                  "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
                  "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
              )
      else:
          raise ValueError(
              "Please initialize the optimizer in one of the following two ways:\n"
              "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
              "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
          )

      if use_local_updates:
          assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
          assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
          # the default factory is simply left unused in this mode; only reject a custom one the user passed
          assert (
              grad_averager_factory is None or grad_averager_factory is GradientAverager
          ), "if local_updates is True, provided grad_averager_factory will not be used"

      self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
      self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
      self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
      self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
      self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
      self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
      self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
      self.next_chunk_timeout = next_chunk_timeout

      self.status_loglevel = logging.INFO if verbose else logging.DEBUG
      self.scheduled_grads: Optional[StepControl] = None
      self.scheduled_state: Optional[StepControl] = None

      self.tracker = self._make_progress_tracker(
          target_batch_size, performance_ema_alpha=performance_ema_alpha, **(tracker_opts or {})
      )
      self.state_averager = self._make_state_averager(
          optimizer=optimizer,
          params=params,
          scheduler=scheduler,
          delta_rule_averaging=use_local_updates and self.delay_state_averaging,
          compression=state_averaging_compression,
          state_compression=load_state_compression,
          average_opt_statistics=average_opt_statistics,
          performance_ema_alpha=performance_ema_alpha,
          extra_tensors=extra_tensors,
          **(averager_opts or {}),
      )
      # NOTE(review): the class docstring says averager_opts is forwarded to both averagers, but here it only
      # reaches the state averager; confirm which behavior is intended before changing it.
      if grad_averager_factory is not None and not use_local_updates:
          self.grad_averager = self._make_gradient_averager(
              reuse_grad_buffers=reuse_grad_buffers,
              grad_averager_factory=grad_averager_factory,
              # without this, the documented grad_compression option would be silently ignored
              compression=grad_compression,
          )
      else:
          self.grad_averager = None

      self._should_check_synchronization_on_update = True  # used in self._should_load_state_from_peers
      self._schema_hash = self._compute_schema_hash()

      # measures the average time from the beginning of self._update_global_epoch to the call to state_averager;
      # used for pre-scheduling the averaging round in state_averager
      self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)

      # note: the line below is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
      # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
      self._step_supports_amp_scaling = reuse_grad_buffers
  def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
      """
      Build and start the TrainingStateAverager for this run.

      Networking, timeout, offloading and logging settings are taken from this Optimizer's attributes. Any extra
      keyword arguments (optimizer, params, scheduler, compression, ...) are forwarded unchanged; passing a key
      that is already set here raises TypeError.
      """
      shared_settings = dict(
          dht=self.dht,
          prefix=f"{self.run_id}_state_averager",
          min_matchmaking_time=self.matchmaking_time,
          allreduce_timeout=self.allreduce_timeout,
          shutdown_timeout=self.shutdown_timeout,
          offload_optimizer=self.offload_optimizer,
          custom_gradients=self.offload_optimizer,
          status_loglevel=self.status_loglevel,
          next_chunk_timeout=self.next_chunk_timeout,
          client_mode=self.client_mode,
          auxiliary=self.auxiliary,
          start=True,
      )
      return TrainingStateAverager(**shared_settings, **kwargs)
  def _make_gradient_averager(self, grad_averager_factory, **kwargs) -> GradientAverager:
      """
      Build and start the gradient averager over the state averager's main parameters.

      Must be called after ``self.state_averager`` exists. With an offloaded optimizer, each optimized parameter's
      ``.grad`` is pointed at the corresponding tensor yielded by ``get_tensors()``, so the optimizer reads the
      averaged gradients directly.

      :param grad_averager_factory: callable that constructs the averager (e.g. GradientAverager).
      :param kwargs: extra keyword arguments forwarded to the factory.
      """
      assert hasattr(self, "state_averager"), "must initialize state averager first"
      averager = grad_averager_factory(
          dht=self.dht,
          prefix=f"{self.run_id}_grad_averager",
          parameters=self.state_averager.main_parameters,
          min_matchmaking_time=self.matchmaking_time,
          allreduce_timeout=self.allreduce_timeout,
          shutdown_timeout=self.shutdown_timeout,
          next_chunk_timeout=self.next_chunk_timeout,
          client_mode=self.client_mode,
          auxiliary=self.auxiliary,
          start=True,
          **kwargs,
      )
      if not self.offload_optimizer:
          return averager

      offloaded_params = [
          param
          for group in self.state_averager.optimizer.param_groups
          for param in group["params"]
      ]
      with averager.get_tensors() as grad_buffers:
          assert len(grad_buffers) == len(offloaded_params)
          for param, buffer in zip(offloaded_params, grad_buffers):
              param.grad = buffer
      return averager
  def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
      """
      Build and start the ProgressTracker that shares training progress under the ``run_id`` prefix.

      :param target_batch_size: global batch size that completes one epoch.
      :param kwargs: extra keyword arguments forwarded to ProgressTracker.
      """
      mode_settings = dict(client_mode=self.client_mode, status_loglevel=self.status_loglevel, start=True)
      return ProgressTracker(
          dht=self.dht,
          prefix=self.run_id,
          target_batch_size=target_batch_size,
          **mode_settings,
          **kwargs,
      )
  def _compute_schema_hash(self) -> int:
      """
      Hash the shapes of all optimized parameters and, with an offloaded optimizer, the identities of their
      ``.grad`` tensors. ``_update_global_epoch`` compares this against the value computed at init.
      """
      params = []
      for group in self.state_averager.optimizer.param_groups:
          params.extend(group["params"])
      shapes = tuple(tuple(p.shape) for p in params)

      # an offloaded optimizer requires that gradient tensors are reused between iterations
      grad_identities = None
      if self.offload_optimizer:
          grad_identities = tuple(id(p.grad) for p in params)
      return hash((grad_identities, shapes))
  def is_alive(self) -> bool:
      """Delegate the liveness check to the underlying state averager."""
      averager = self.state_averager
      return averager.is_alive()
  @property
  def local_epoch(self) -> int:
      """
      This worker's current epoch, as stored by the state averager and kept synchronized with peers.

      If this peer's local_epoch lags behind others, it will automatically re-synchronize by downloading state
      from another peer. One epoch corresponds to accumulating target_batch_size across all active devices.
      """
      averager = self.state_averager
      return averager.local_epoch
  @property
  def local_progress(self) -> LocalTrainingProgress:
      """This peer's own training progress, as held by the progress tracker."""
      tracker = self.tracker
      return tracker.local_progress
  @property
  def use_local_updates(self) -> bool:
      """True when no gradient averager was created, i.e. parameters are updated from local gradients only."""
      averager = self.grad_averager
      return averager is None
  @property
  def use_gradient_averaging(self) -> bool:
      """True when a gradient averager exists, i.e. gradients are averaged with peers before each global step."""
      return not (self.grad_averager is None)
  def step(
      self,
      closure: Optional[Callable[[], torch.Tensor]] = None,
      batch_size: Optional[int] = None,
      grad_scaler: Optional[GradScaler] = None,
  ):
      """
      Update training progress after accumulating another local batch size. Depending on the configuration, this will
      report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.

      :param closure: A closure that reevaluates the model and returns the loss (called with gradients enabled).
      :param batch_size: optional override for batch_size_per_step from init.
      :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler.
      :returns: the value returned by ``closure``, or None if no closure was given.
      :raises ValueError: if grad_scaler is not a hivemind.GradScaler, if a non-auxiliary peer has no batch size,
          or if an auxiliary peer is given a closure, a batch size or a grad_scaler.
      :note: this .step is different from normal pytorch optimizers in several key ways. See the class docstring.
      """
      if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
          raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
      if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
          raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
      if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
          raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
      batch_size = batch_size if batch_size is not None else self.batch_size_per_step

      # if delayed updates finished before step, apply these updates; otherwise do nothing
      self.state_averager.step(apply_delayed_updates=True)

      loss = None
      if closure is not None:
          with torch.enable_grad():
              loss = closure()

      if not self.auxiliary and self._should_load_state_from_peers():
          logger.log(self.status_loglevel, "Peer is out of sync")
          self.load_state_from_peers()
          return loss  # local gradients were computed with out-of-sync parameters, must start over

      if self.use_gradient_averaging:
          # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
          if not self.auxiliary:
              grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
              if not grads_are_valid:
                  return loss  # local gradients were reset due to overflow, must start over

          self._maybe_schedule_gradient_averaging()
          self._maybe_schedule_state_averaging()

      else:
          # use_local_updates=True: update parameters on every step independently of other peers
          if not self.auxiliary:
              if grad_scaler is not None:
                  with grad_scaler.running_global_step():
                      # call unscale_ outside the assert so that it still runs under `python -O`
                      unscaled = grad_scaler.unscale_(self)
                      assert unscaled

              new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
              self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
              self._maybe_schedule_state_averaging()

              self.state_averager.step(
                  increment_epoch=False,
                  optimizer_step=True,
                  delay_optimizer_step=self.delay_optimizer_step,
                  grad_scaler=grad_scaler,
              )

      if self.tracker.ready_to_update_epoch:
          self._update_global_epoch(grad_scaler)

      return loss
  def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
      """
      Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step.

      Runs with tracker updates paused. The state averager is stepped with ``increment_epoch=True``, after which the
      tracker is moved to the state averager's new ``local_epoch`` and accumulated gradients (if any) are reset.

      :param grad_scaler: optional hivemind GradScaler forwarded to gradient averaging and the optimizer step.
      """
      assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
      _epoch_start_time = time.perf_counter()

      with self.tracker.pause_updates():
          wait_for_trigger = None

          if self.use_gradient_averaging:
              logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
              if self.delay_optimizer_step:
                  self.state_averager.step(wait_for_delayed_updates=True)

              began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
              if not began_averaging_gradients:
                  # failed to start gradient averaging due to an internal error
                  self.grad_averager.load_accumulators_into_averager_()
              elif self.delay_grad_averaging:
                  # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                  wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
              else:
                  # delay_grad_averaging=False, average gradients immediately
                  self._average_gradients_and_load_into_optimizer(self.scheduled_grads)

          next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
          swarm_not_empty = self.tracker.global_progress.num_peers > 1
          should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
          should_average_state = (
              swarm_not_empty
              and next_epoch % self.average_state_every == 0
              and not self.state_averager.averaging_in_progress
          )

          if should_average_state and self.scheduled_state is not None:
              if self.scheduled_state.triggered or self.scheduled_state.done():
                  logger.log(
                      self.status_loglevel,
                      f"Not using pre-scheduled group for state averaging because it "
                      f"was already used elsewhere: {self.scheduled_state}",
                  )
                  self.scheduled_state = None
              self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)

          self.state_averager.step(
              increment_epoch=True,
              wait_for_trigger=wait_for_trigger,
              optimizer_step=should_perform_optimizer_step,
              delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
              grad_scaler=grad_scaler,
              averaging_round=should_average_state,
              delay_averaging=self.delay_state_averaging and not self.auxiliary,
              averaging_control=self.scheduled_state if should_average_state else None,
              averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
          )

          # a pre-scheduled state round that was not handed to the state averager is no longer needed
          if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
              self.scheduled_state.cancel()
          self.scheduled_state = None

          self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
          self._should_check_synchronization_on_update = True
          # the above line ensures that peers check for *strict* synchronization once per epoch

          if not self.client_mode:
              self.state_averager.state_sharing_priority = self.local_epoch

          if self.use_gradient_averaging and not self.auxiliary:
              self.grad_averager.reset_accumulated_grads_()
              if not self.client_mode:
                  self.grad_averager.state_sharing_priority = self.local_epoch

          logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}")
  def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
      """
      Begin an all-reduce round to average gradients; return True if succeeded, False if failed.

      If a pre-scheduled gradient round exists but averaging did not begin, this peer either tags along with zero
      weight (when other peers are present) or loads its local gradients into the optimizer and cancels the round.
      In both cases ``self.scheduled_grads`` is reset to None.

      :param grad_scaler: if given, gradients are unscaled before averaging.
      """
      if grad_scaler is not None:
          with grad_scaler.running_global_step():
              # call unscale_ outside the assert so that it still runs under `python -O`
              unscaled = grad_scaler.unscale_(self)
              assert unscaled

      began_averaging_gradients = False
      if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
          logger.log(
              self.status_loglevel,
              f"Not using pre-scheduled group for gradient averaging because it "
              f"was already used elsewhere: {self.scheduled_grads}",
          )
          self.scheduled_grads = None

      elif self.tracker.global_progress.num_peers > 1:
          try:
              self.scheduled_grads = self.grad_averager.step(
                  control=self.scheduled_grads, reset_accumulators=True, wait=False
              )
              began_averaging_gradients = True
          except BaseException as e:
              # NOTE(review): BaseException also swallows KeyboardInterrupt/SystemExit; kept as-is because
              # cancellation errors may need the same fallback path - confirm before narrowing.
              logger.exception(e)

      if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
          if self.tracker.global_progress.num_peers > 1:
              logger.log(self.status_loglevel, "Tagging along for a pre-scheduled gradient averaging round")
              self._tag_along_with_zero_weight(self.scheduled_grads)
          else:
              logger.log(self.status_loglevel, "Skipping pre-scheduled averaging round: there are no other peers")
              self._load_local_gradients_into_optimizer()
              self.scheduled_grads.cancel()
          self.scheduled_grads = None
      return began_averaging_gradients
  468. def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
  469. """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
  470. assert not self.use_local_updates and not self.auxiliary
  471. if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
  472. logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
  473. self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
  474. self.grad_averager.reset_accumulated_grads_()
  475. return False
  476. self.grad_averager.accumulate_grads_(batch_size)
  477. self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
  478. return True
  479. def _maybe_schedule_gradient_averaging(self) -> None:
  480. """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
  481. assert self.use_gradient_averaging
  482. if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
  483. if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
  484. eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
  485. eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
  486. logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
  487. self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
  488. def _maybe_schedule_state_averaging(self) -> None:
  489. """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
  490. next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
  491. if next_epoch % self.average_state_every != 0:
  492. return # averaging is not performed at this epoch
  493. if self.state_averager.averaging_in_progress:
  494. return # previous run is still in progress
  495. if self.delay_before_state_averaging.num_updates == 0:
  496. return # not enough data to accurately pre-schedule
  497. estimated_time = self.tracker.estimated_next_update_time
  498. estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
  499. estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
  500. eta_seconds_to_averaging = estimated_time - get_dht_time()
  501. if eta_seconds_to_averaging <= self.matchmaking_time:
  502. if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
  503. min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
  504. actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
  505. logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
  506. self.scheduled_state = self.state_averager.schedule_step(
  507. gather=next_epoch, timeout=self.averaging_timeout
  508. )
  509. def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
  510. """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
  511. assert self.use_gradient_averaging and maybe_step_control is None or maybe_step_control.triggered
  512. averaged_gradients = False
  513. try:
  514. if maybe_step_control is not None:
  515. group_info = maybe_step_control.result(self.averaging_timeout)
  516. logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
  517. self._load_averaged_gradients_into_optimizer_()
  518. averaged_gradients = True
  519. else:
  520. logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
  521. except BaseException as e:
  522. logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
  523. if not averaged_gradients:
  524. self._load_local_gradients_into_optimizer()
  525. def _load_averaged_gradients_into_optimizer_(self):
  526. """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
  527. assert self.use_gradient_averaging
  528. if self.offload_optimizer:
  529. pass # averaged gradients are already baked into optimizer, see _make_gradient_averager
  530. else:
  531. # copy averaged gradients into optimizer .grad buffers
  532. optimized_param_groups = self.state_averager.optimizer.param_groups
  533. optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
  534. with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
  535. assert len(averaged_gradients) == len(optimized_parameters)
  536. for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
  537. opt_param.grad.copy_(averaged_grad, non_blocking=True)
  538. self.grad_averager.notify_used_averaged_gradients()
  539. def _load_local_gradients_into_optimizer(self):
  540. """Fallback to using local gradients in the optimizer (instead of averaged gradients)"""
  541. logger.log(self.status_loglevel, f"Proceeding with local gradients")
  542. self.grad_averager.load_accumulators_into_averager_()
  543. # note: we load gradients into grad_averager even though there is only one peer because of two reasons:
  544. # - if offload_optimizer, then we must load gradients onto the CPU gradient buffers used by the optimizer
  545. # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
  546. self._load_averaged_gradients_into_optimizer_()
  547. def zero_grad(self, set_to_none: bool = False):
  548. """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
  549. if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
  550. raise ValueError(
  551. f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
  552. f"call zero_grad manually. Gradients will be refreshed internally"
  553. )
  554. for param_group in self.param_groups:
  555. for param in param_group["params"]:
  556. if param.grad is None:
  557. pass
  558. elif set_to_none:
  559. param.grad = None
  560. else:
  561. param.grad.zero_()
  562. def _should_load_state_from_peers(self) -> bool:
  563. """
  564. If true, peer will discard local progress and attempt to download state from peers.
  565. This method allows peer to continue training in two cases:
  566. - peer is on the same epoch as other collaborators - keep training normally
  567. - peer was on the same epoch and accumulated some grads, but some collaborators
  568. have just transitioned to the next epoch - this peer should also transition.
  569. :note: The latter case occurs due to the lack of network synchrony: the first peer that
  570. detects enough samples will transition to the next step and start counting samples anew.
  571. Some other peers may take time before they check with DHT and observe that
  572. - the global epoch is technically one epoch ahead of the current one and
  573. - the remaining (non-transitioned) peers no longer have target_batch_size between them
  574. If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
  575. """
  576. if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
  577. self._should_check_synchronization_on_update = False
  578. return self.local_epoch != self.tracker.global_epoch # require exact synchronization once per step
  579. return self.local_epoch < self.tracker.global_epoch - 1 # catch up if a peer just switched to next epoch
  580. def is_synchronized_with_peers(self) -> bool:
  581. """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number."""
  582. return self.local_epoch >= self.tracker.global_epoch - 1
  583. def load_state_from_peers(self, **kwargs):
  584. """
  585. Attempt to load the newest collaboration state from other peers within the same run_id.
  586. If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
  587. """
  588. # note: we tag along for the next all-reduce because the run may have already started and cancelling it
  589. # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
  590. if self.scheduled_grads is not None and not self.scheduled_grads.done():
  591. self._tag_along_with_zero_weight(self.scheduled_grads)
  592. self.scheduled_grads = None
  593. self.state_averager.step(wait_for_delayed_updates=True)
  594. with self.tracker.pause_updates():
  595. while True:
  596. try:
  597. self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
  598. self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
  599. break
  600. except KeyboardInterrupt:
  601. raise
  602. except BaseException as e:
  603. logger.exception(f"Failed to load state from peers: {e}, retrying ...")
  604. continue
  605. if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
  606. logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
  607. self.state_averager.local_epoch = self.tracker.global_epoch
  608. self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
  609. if not self.client_mode:
  610. self.state_averager.state_sharing_priority = self.local_epoch
  611. if self.use_gradient_averaging:
  612. self.grad_averager.reset_accumulated_grads_()
  613. if not self.client_mode:
  614. self.grad_averager.state_sharing_priority = self.local_epoch
  615. def state_dict(self) -> dict:
  616. state_dict = self.state_averager.optimizer.state_dict()
  617. state_dict["state"]["local_epoch"] = self.local_epoch
  618. return state_dict
  619. def load_state_dict(self, state_dict: dict):
  620. if "local_epoch" in state_dict["state"]:
  621. self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
  622. return self.state_averager.optimizer.load_state_dict(state_dict)
  623. @property
  624. def state(self):
  625. return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
  626. @property
  627. def opt(self) -> TorchOptimizer:
  628. return self.state_averager.optimizer
  629. @property
  630. def param_groups(self) -> ParamGroups:
  631. next_index = 0
  632. param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
  633. for param_group in param_groups:
  634. num_params = len(param_group["params"])
  635. main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
  636. param_group["params"] = main_params_for_group
  637. next_index += num_params
  638. assert next_index == len(self.state_averager.main_parameters)
  639. return param_groups
  640. def add_param_group(self, param_group: dict) -> None:
  641. raise ValueError(
  642. f"{self.__class__.__name__} does not support calling add_param_group after creation. "
  643. f"Please provide all parameter groups at init"
  644. )
  645. def __repr__(self):
  646. return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
  647. def _tag_along_with_zero_weight(self, control: StepControl):
  648. """Wait for a running averaging round to finish with zero weight."""
  649. if not control.triggered:
  650. control.weight = 0
  651. control.allow_allreduce()
  652. if not control.done():
  653. try:
  654. control.result(self.averaging_timeout)
  655. except BaseException as e:
  656. logger.exception(e)
  657. if not control.done():
  658. control.cancel()
  659. def shutdown(self):
  660. logger.log(self.status_loglevel, "Sending goodbye to peers...")
  661. self.tracker.shutdown(self.shutdown_timeout)
  662. self.state_averager.step(wait_for_delayed_updates=True)
  663. for scheduled_round in self.scheduled_grads, self.scheduled_state:
  664. if scheduled_round is not None:
  665. if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
  666. scheduled_round.cancel()
  667. else:
  668. self._tag_along_with_zero_weight(scheduled_round)
  669. logger.log(self.status_loglevel, "Shutting down averagers...")
  670. self.state_averager.shutdown()
  671. if self.use_gradient_averaging:
  672. self.grad_averager.shutdown()
  673. logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down")
  674. def __del__(self):
  675. if self._parent_pid == os.getpid() and self.is_alive():
  676. self.shutdown()