optimizer.py

from __future__ import annotations

import logging
import os
import time
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch

from hivemind.averaging.control import StepControl
from hivemind.compression import CompressionBase, NoCompression
from hivemind.dht import DHT
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.progress_tracker import ProgressTracker
from hivemind.optim.experimental.state_averager import (
    LRSchedulerBase,
    OptimizerFactory,
    Parameters,
    ParamGroups,
    SchedulerFactory,
    TorchOptimizer,
    TrainingStateAverager,
)
from hivemind.optim.grad_scaler import GradScaler
from hivemind.utils import get_dht_time, get_logger, PerformanceEMA

logger = get_logger(__name__)


class Optimizer(torch.optim.Optimizer):
    """
    Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.

    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
    there are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
    or even fully asynchronous (local_updates=True). However, these options require careful tuning.

    The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:

    >>> model = transformers.AutoModel.from_pretrained("albert-xxlarge-v2")
    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam, params=model.parameters(),
    >>>                          target_batch_size=4096, batch_size_per_step=4)  # recommended way to create Optimizer
    >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters()))
    >>> while True:
    >>>     loss = compute_loss_on_batch(model, batch_size=4)
    >>>     opt.zero_grad()
    >>>     loss.backward()
    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)

    However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:

    - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
    - if your peer lags behind the rest of the swarm, it will download the latest state from other peers.

    :note: hivemind.Optimizer can be used the same way as any other pytorch optimizer, but there is one limitation:
      learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
      (and not the number of calls to opt.step). This is because peers are allowed to join midway through training,
      when others have already made some progress and changed their learning rates accordingly.

    :param dht: a running hivemind.DHT instance connected to other peers
    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys
    :note: peers with the same run_id should *generally* train the same model and use the same optimizer configuration.
      Some options can be safely changed by individual peers: `batch_size_per_step`, `client_mode`, `auxiliary`,
      `reuse_grad_buffers`, `offload_optimizer`, and `verbose`. In some cases, other options may also be tuned
      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.

    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples

    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
    :note: creating hivemind.Optimizer with params=model.parameters() and optimizer=lambda params: make_optim(params)
      is required for advanced options: offload_optimizer, delay_optimizer_step and delay_grad_averaging.

    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler
    :note: the learning rate scheduler will adjust learning rate based on the collaboration-wide epoch, not the number
      of local calls to optimizer.step; this is required to keep different peers synchronized.

    :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all

    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
    :note: offload_optimizer, delay_optimizer_step and delay_grad_averaging require that the optimizer is
      created as follows: `hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())`

    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many **epochs**.
      This reduces the communication overhead, but can cause parameters to diverge if it is too large.
    :note: the maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
      hardly ever skip averaging rounds, they can average state less frequently. Network failures, lossy gradient
      compression and local_updates cause parameters to diverge faster and require more frequent averaging.

    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients
    :note: even if use_local_updates=True, the learning rate scheduler will still be called once per target_batch_size.

    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
    :note: client_mode=True and auxiliary=True are mutually exclusive; auxiliary also requires batch_size_per_step=None

    :param grad_compression: compression strategy used for averaging gradients, default = no compression
    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
    :param load_state_compression: compression strategy for loading state from peers, default = no compression
    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.

    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
    :param performance_ema_alpha: moving average alpha used in ProgressTracker, TrainingStateAverager and Optimizer
    :param verbose: if True, report internal events such as accumulating gradients and running background tasks

    Internally, hivemind.Optimizer consists of 4 components:

    - DHT, a decentralized key-value storage used for coordination across the swarm
    - GradientAverager that is responsible for aggregating gradients with peers for global steps (can be disabled)
    - TrainingStateAverager that holds parameters and optimizer/scheduler statistics, keeping them weakly synchronized
      by averaging with peers. It can also download these variables from other peers if your peer is out of sync.
    - ProgressTracker that uses DHT to track the global training progress: the number of steps or samples accumulated
    """
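
    # As the docstring notes, learning-rate schedules, curricula and other time-dependent features should key off
    # `Optimizer.local_epoch` rather than the number of .step() calls. An illustrative sketch (WARMUP_EPOCHS and
    # unfreeze_encoder are assumed user-side names, not part of this module):
    #     if opt.local_epoch >= WARMUP_EPOCHS:
    #         unfreeze_encoder(model)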

    def __init__(
        self,
        *,
        dht: DHT,
        run_id: str,
        target_batch_size: int,
        batch_size_per_step: Optional[int] = None,
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        params: Optional[Union[Parameters, ParamGroups]] = None,
        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
        matchmaking_time: Optional[float] = 15.0,
        averaging_timeout: Optional[float] = 300.0,
        load_state_timeout: float = 600.0,
        reuse_grad_buffers: bool = False,
        offload_optimizer: Optional[bool] = None,
        delay_optimizer_step: Optional[bool] = None,
        delay_grad_averaging: bool = False,
        delay_state_averaging: bool = True,
        average_state_every: int = 1,
        use_local_updates: bool = False,
        client_mode: Optional[bool] = None,
        auxiliary: bool = False,
        grad_compression: CompressionBase = NoCompression(),
        state_averaging_compression: CompressionBase = NoCompression(),
        load_state_compression: CompressionBase = NoCompression(),
        average_opt_statistics: Sequence[str] = (),
        extra_tensors: Sequence[torch.Tensor] = (),
        averager_opts: Optional[dict] = None,
        tracker_opts: Optional[dict] = None,
        performance_ema_alpha: float = 0.1,
        shutdown_timeout: float = 5,
        verbose: bool = False,
    ):
        client_mode = client_mode if client_mode is not None else dht.client_mode
        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
        if callable(optimizer) and params is not None:
            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
                raise ValueError(
                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params))"
                )
        else:
            raise ValueError(
                "Please initialize the optimizer in one of the following two ways:\n"
                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params))\n"
                "(B) hivemind.Optimizer(..., optimizer=pre_initialized_optimizer)"
            )
        if use_local_updates:
            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"

        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
        self.shutdown_timeout = shutdown_timeout

        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
        self.scheduled_grads: Optional[StepControl] = None
        self.scheduled_state: Optional[StepControl] = None

        self.tracker = self._make_progress_tracker(
            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
        )
        self.state_averager = self._make_state_averager(
            optimizer=optimizer,
            params=params,
            scheduler=scheduler,
            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
            compression=state_averaging_compression,
            state_compression=load_state_compression,
            average_opt_statistics=average_opt_statistics,
            performance_ema_alpha=performance_ema_alpha,
            extra_tensors=extra_tensors,
            **averager_opts or {},
        )
        if not use_local_updates:
            self.grad_averager = self._make_gradient_averager(
                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
            )
        else:
            self.grad_averager = None

        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
        self._schema_hash = self._compute_schema_hash()
        self._parent_pid = os.getpid()

        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
        # used for pre-scheduling the averaging round in state_averager

        self._step_supports_amp_scaling = reuse_grad_buffers
        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
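
    # A brief recap of the two construction modes enforced by the checks above (illustrative, not additional API;
    # `model` is assumed to be a torch.nn.Module defined by the caller):
    # (A) factory mode, required for offload_optimizer / delay_optimizer_step / delay_grad_averaging:
    #         hivemind.Optimizer(dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
    #                            params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
    #                            offload_optimizer=True)
    # (B) pre-initialized mode, which supports only the default (synchronous, non-offloaded) behavior:
    #         hivemind.Optimizer(dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
    #                            optimizer=torch.optim.Adam(model.parameters()))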

    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
        return TrainingStateAverager(
            dht=self.dht,
            prefix=f"{self.run_id}_state_averager",
            allreduce_timeout=self.averaging_timeout,
            shutdown_timeout=self.shutdown_timeout,
            offload_optimizer=self.offload_optimizer,
            custom_gradients=self.offload_optimizer,
            status_loglevel=self.status_loglevel,
            client_mode=self.client_mode,
            auxiliary=self.auxiliary,
            start=True,
            **kwargs,
        )

    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
        assert hasattr(self, "state_averager"), "must initialize state averager first"
        grad_averager = GradientAverager(
            dht=self.dht,
            prefix=f"{self.run_id}_grad_averager",
            parameters=self.state_averager.main_parameters,
            allreduce_timeout=self.averaging_timeout,
            shutdown_timeout=self.shutdown_timeout,
            client_mode=self.client_mode,
            auxiliary=self.auxiliary,
            start=True,
            **kwargs,
        )
        if self.offload_optimizer:
            optimized_param_groups = self.state_averager.optimizer.param_groups
            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
            with grad_averager.get_tensors() as averaged_gradients:
                assert len(averaged_gradients) == len(optimized_parameters)
                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
                    opt_param.grad = averaged_grad
        return grad_averager

    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
        return ProgressTracker(
            dht=self.dht,
            prefix=self.run_id,
            target_batch_size=target_batch_size,
            client_mode=self.client_mode,
            status_loglevel=self.status_loglevel,
            start=True,
            **kwargs,
        )

    def _compute_schema_hash(self) -> int:
        optimized_param_groups = self.state_averager.optimizer.param_groups
        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
        # offloaded optimizer requires that gradient tensors are reused between iterations
        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
        return hash((grad_ids, param_shapes))
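
    # The schema hash above is re-checked at the start of _update_global_epoch: changing parameter shapes or,
    # with an offloaded optimizer, replacing .grad tensors between iterations trips the
    # "parameters or gradients changed during iteration" assertion there.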

    def is_alive(self) -> bool:
        return self.state_averager.is_alive()

    @property
    def local_epoch(self) -> int:
        return self.state_averager.local_epoch

    @property
    def use_local_updates(self) -> bool:
        return self.grad_averager is None

    @property
    def use_gradient_averaging(self) -> bool:
        return self.grad_averager is not None

    def step(
        self,
        closure: Optional[Callable[[], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        grad_scaler: Optional[GradScaler] = None,
    ):
        """
        Report that gradients were accumulated over batch_size additional samples; optionally update model parameters.

        :param closure: a closure that reevaluates the model and returns the loss
        :param batch_size: optional override for batch_size_per_step from init
        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        # if delayed updates finished before step, apply these updates; otherwise do nothing
        self.state_averager.step(apply_delayed_updates=True)

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.auxiliary and self.should_load_state_from_peers():
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return loss  # local gradients were computed with out-of-sync parameters, must start over

        if self.use_gradient_averaging:
            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
            if not self.auxiliary:
                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
                if not grads_are_valid:
                    return loss  # local gradients were reset due to overflow, must start over
            self._maybe_schedule_gradient_averaging()
            self._maybe_schedule_state_averaging()
        else:
            # use_local_updates=True: update parameters on every step independently of other peers
            if not self.auxiliary:
                if grad_scaler is not None:
                    with grad_scaler.running_global_step():
                        assert grad_scaler.unscale_(self)

                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
                self._maybe_schedule_state_averaging()

                self.state_averager.step(
                    increment_epoch=False,
                    optimizer_step=True,
                    delay_optimizer_step=self.delay_optimizer_step,
                    grad_scaler=grad_scaler,
                )

        if self.tracker.ready_to_update_epoch:
            self._update_global_epoch(grad_scaler)

        return loss
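
    # With the docstring's example settings (target_batch_size=4096, batch_size_per_step=4), the swarm as a whole
    # makes roughly 4096 / 4 = 1024 calls to .step() per epoch; with, say, 32 peers of equal throughput, each peer
    # contributes about 32 of those calls before gradients are averaged and parameters are updated.
    # The peer count here is purely an illustrative assumption.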

    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
        _epoch_start_time = time.perf_counter()

        with self.tracker.pause_updates():
            wait_for_trigger = None

            if self.use_gradient_averaging:
                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                if not began_averaging_gradients:
                    pass  # failed to start gradient averaging due to an internal error
                if self.delay_grad_averaging:
                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
                else:
                    # delay_grad_averaging=False, average gradients immediately
                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)

            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
            swarm_not_empty = self.tracker.global_progress.num_peers > 1
            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
            should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0

            if should_average_state:
                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)

            self.state_averager.step(
                increment_epoch=True,
                wait_for_trigger=wait_for_trigger,
                optimizer_step=should_perform_optimizer_step,
                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
                grad_scaler=grad_scaler,
                averaging_round=should_average_state,
                delay_averaging=self.delay_state_averaging and not self.auxiliary,
                averaging_control=self.scheduled_state if should_average_state else None,
                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
            )

            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
                self.scheduled_state.cancel()
            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
            self.scheduled_grads = self.scheduled_state = None

            self._should_check_synchronization_on_update = True
            # the above line ensures that peers check for *strict* synchronization once per epoch

            if not self.client_mode:
                self.state_averager.state_sharing_priority = self.local_epoch
            if self.use_gradient_averaging and not self.auxiliary:
                self.grad_averager.reset_accumulated_grads_()
                if not self.client_mode:
                    self.grad_averager.state_sharing_priority = self.local_epoch

            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")

    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
        if grad_scaler is not None:
            with grad_scaler.running_global_step():
                assert grad_scaler.unscale_(self)

        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_grads}")
            self.scheduled_grads = None

        began_averaging_gradients = False
        if self.tracker.global_progress.num_peers > 1:
            try:
                self.scheduled_grads = self.grad_averager.step(
                    control=self.scheduled_grads, reset_accumulators=True, wait=False
                )
                assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
                began_averaging_gradients = True
            except BaseException as e:
                logger.exception(e)

        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
            self.scheduled_grads.cancel()
            self.scheduled_grads = None
        return began_averaging_gradients

    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
        assert not self.use_local_updates and not self.auxiliary
        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
            self.grad_averager.reset_accumulated_grads_()
            return False

        self.grad_averager.accumulate_grads_(batch_size)
        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
        return True

    def _maybe_schedule_gradient_averaging(self) -> None:
        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
        assert self.use_gradient_averaging
        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
                if self.delay_grad_averaging:
                    # wait for previous averaging to finish before starting a new one
                    self.state_averager.step(wait_for_delayed_updates=True)
                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
                scheduled_time = self.tracker.estimated_next_update_time
                if self.client_mode:
                    scheduled_time = get_dht_time() + self.averaging_timeout
                self.scheduled_grads = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)

    def _maybe_schedule_state_averaging(self) -> None:
        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
        if next_epoch % self.average_state_every != 0:
            return  # averaging is not performed at this epoch
        estimated_time = self.tracker.estimated_next_update_time
        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
        eta_seconds_to_averaging = estimated_time - get_dht_time()

        if eta_seconds_to_averaging <= self.matchmaking_time:
            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
                if self.client_mode:
                    estimated_time = get_dht_time() + self.averaging_timeout
                self.scheduled_state = self.state_averager.schedule_step(
                    estimated_time, gather=next_epoch, timeout=self.averaging_timeout
                )

    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
        averaged_gradients = False

        try:
            if maybe_step_control is not None:
                group_info = maybe_step_control.result(self.averaging_timeout)
                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
                self._load_averaged_gradients_into_optimizer_()
                averaged_gradients = True
            else:
                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
        except BaseException as e:
            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")

        if not averaged_gradients:
            logger.log(self.status_loglevel, f"Proceeding with local gradients")
            self.grad_averager.load_accumulators_into_averager_()
            self._load_averaged_gradients_into_optimizer_()

    def _load_averaged_gradients_into_optimizer_(self):
        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
        assert self.use_gradient_averaging
        if self.offload_optimizer:
            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
        else:
            # copy averaged gradients into optimizer .grad buffers
            optimized_param_groups = self.state_averager.optimizer.param_groups
            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
                assert len(averaged_gradients) == len(optimized_parameters)
                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
                    opt_param.grad.copy_(averaged_grad, non_blocking=True)

        self.grad_averager.notify_used_averaged_gradients()

    def zero_grad(self, set_to_none: bool = False):
        """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
            raise ValueError(
                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                f"call zero_grad manually. Gradients will be refreshed internally."
            )
        for param_group in self.param_groups:
            for param in param_group["params"]:
                if param.grad is None:
                    pass
                elif set_to_none:
                    param.grad = None
                else:
                    param.grad.zero_()

    def should_load_state_from_peers(self) -> bool:
        """
        If true, peer will discard local progress and attempt to download state from peers.
        This method allows the peer to continue training in two cases:
         - the peer is on the same epoch as other collaborators - keep training normally;
         - the peer was on the same epoch and accumulated some grads, but some collaborators
           have just transitioned to the next epoch - this peer should also transition.

        :note: the latter case occurs due to the lack of network synchrony: the first peer that
          detects enough samples will transition to the next step and start counting samples anew.
          Some other peers may take time before they check with DHT and observe that
           - the global epoch is technically one epoch ahead of the current one and
           - the remaining (non-transitioned) peers no longer have target_batch_size between them.
          If this is the case, the peer should transition to the next epoch, but it does *not* need to re-load state.
        """
        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
            self._should_check_synchronization_on_update = False
            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch

    def load_state_from_peers(self, **kwargs):
        """Attempt to fetch the newest collaboration state from other peers"""
        self._finish_scheduled_averaging()

        with self.tracker.pause_updates():
            while True:
                try:
                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                    break
                except KeyboardInterrupt:
                    raise
                except BaseException as e:
                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                    continue

            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
                self.state_averager.local_epoch = self.tracker.global_epoch

            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)

            if not self.client_mode:
                self.state_averager.state_sharing_priority = self.local_epoch

            if self.use_gradient_averaging:
                self.grad_averager.reset_accumulated_grads_()
                if not self.client_mode:
                    self.grad_averager.state_sharing_priority = self.local_epoch

    def _finish_scheduled_averaging(self):
        if self.scheduled_grads is not None:
            self.scheduled_grads.weight = 0
            self.scheduled_grads.allow_allreduce()
        if self.scheduled_state is not None:
            self.scheduled_state.weight = 0
            self.scheduled_state.allow_allreduce()

        if self.scheduled_grads is not None:
            try:
                self.scheduled_grads.result(timeout=max(0.0, self.scheduled_grads.deadline - get_dht_time()))
            except BaseException as e:
                logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
            if not self.scheduled_grads.done():
                self.scheduled_grads.cancel()
        if self.scheduled_state is not None:
            try:
                self.scheduled_state.result(timeout=max(0.0, self.scheduled_state.deadline - get_dht_time()))
            except BaseException as e:
                logger.log(self.status_loglevel, f"Caught {e} while averaging state")
            if not self.scheduled_state.done():
                self.scheduled_state.cancel()

    def state_dict(self) -> dict:
        state_dict = self.state_averager.optimizer.state_dict()
        state_dict["state"]["local_epoch"] = self.local_epoch
        return state_dict

    def load_state_dict(self, state_dict: dict):
        if "local_epoch" in state_dict["state"]:
            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
        return self.state_averager.optimizer.load_state_dict(state_dict)

    @property
    def state(self):
        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)

    @property
    def opt(self) -> TorchOptimizer:
        return self.state_averager.optimizer

    @property
    def param_groups(self) -> ParamGroups:
        next_index = 0
        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
        for param_group in param_groups:
            num_params = len(param_group["params"])
            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
            param_group["params"] = main_params_for_group
            next_index += num_params
        assert next_index == len(self.state_averager.main_parameters)
        return param_groups

    def add_param_group(self, param_group: dict) -> None:
        raise ValueError(
            f"{self.__class__.__name__} does not support calling add_param_group after creation. "
            f"Please provide all parameter groups at init."
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"

    def shutdown(self):
        logger.debug("Sending goodbye to peers...")
        self._finish_scheduled_averaging()
        self.tracker.shutdown(self.shutdown_timeout)

        logger.debug("Shutting down averagers...")
        self.state_averager.step(wait_for_delayed_updates=True)
        self.state_averager.shutdown()
        if self.use_gradient_averaging:
            self.grad_averager.shutdown()
        logger.debug(f"{self.__class__.__name__} is shut down.")

    def __del__(self):
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()
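

if __name__ == "__main__":
    # Illustrative, self-contained usage sketch (not part of the library): trains a toy linear model on random data
    # in a single-peer swarm, mirroring the loop from the class docstring. The toy model, synthetic loss, and
    # run_id="demo_run" / batch-size values below are assumptions made for this example only.
    model = torch.nn.Linear(16, 1)
    dht = DHT(start=True)  # first peer of a new swarm; pass initial_peers=[...] to join an existing one
    opt = Optimizer(
        dht=dht,
        run_id="demo_run",
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params),  # factory form, see the checks in __init__
        target_batch_size=256,  # global number of samples per collaborative optimizer step
        batch_size_per_step=4,  # local samples accumulated by each call to .step()
        verbose=True,
    )
    try:
        for _ in range(1000):
            batch = torch.randn(4, 16)
            loss = model(batch).pow(2).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()  # accumulates gradients; averages with peers once target_batch_size is reached globally
    finally:
        opt.shutdown()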