optimizer.py

from __future__ import annotations

import logging
import os
from functools import partial
from typing import Callable, Optional, Union

import torch

from hivemind.averaging.control import StepControl
from hivemind.dht import DHT
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.progress_tracker import ProgressTracker
from hivemind.optim.experimental.state_averager import (
    LRSchedulerBase,
    OptimizerFactory,
    Parameters,
    ParamGroups,
    SchedulerFactory,
    TorchOptimizer,
    TrainingStateAverager,
)
from hivemind.optim.grad_scaler import GradScaler
from hivemind.utils import get_dht_time, get_logger

logger = get_logger(__name__)


class Optimizer(torch.optim.Optimizer):
    """
    Hivemind Optimizer wraps your regular PyTorch Optimizer for training in a swarm of peers. It can be configured
    with synchronous, delayed or asynchronous updates to trade off optimization guarantees against compute
    utilization.

    The Optimizer is meant as a drop-in replacement for your regular PyTorch code:

    >>> model = transformers.AutoModel.from_pretrained("albert-xxlarge-v2")
    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
    >>> opt = hivemind.Optimizer(dht=dht, prefix="run_42", params=model.parameters(), optimizer=torch.optim.Adam,
    >>>                          target_batch_size=4096, batch_size_per_step=4)
    >>> while True:
    >>>     loss = compute_loss_on_batch(model, batch_size=4)
    >>>     opt.zero_grad()
    >>>     loss.backward()
    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)

    However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:

    - accumulate a minibatch of data towards the (global) target batch size without changing parameters (yet),
    - after accumulating the target batch size, all-reduce gradients with peers and perform the optimizer step,
    - if, for any reason, your peer lags behind the rest of the swarm, it will load state from up-to-date peers.

    :note: hivemind.Optimizer can be used the same way as any other PyTorch optimizer, with one limitation:
      learning rate schedulers, curricula and other time-dependent features should use opt.global_step (and not the
      number of local forward-backward cycles). This is because any device can join midway through training, when
      other peers have already made some progress and changed their learning rate accordingly.

    :param dht: a running hivemind.DHT instance connected to other peers
    :param prefix: a unique name of this experiment, used as a common prefix for all DHT keys
    :param target_batch_size: perform an optimizer step after all peers collectively accumulate this many samples
    :param batch_size_per_step: before each call to .step, the user should accumulate gradients over this many samples
    :param optimizer: a standard PyTorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
    :param scheduler: if specified, use this scheduler to update the optimizer learning rate
    :note: If you are using this optimizer with an lr_scheduler, it is recommended to pass the scheduler explicitly
      into this class. Otherwise, the scheduler may not be synchronized between peers.
    :param matchmaking_time: when looking for a group, wait for peers to join for up to this many seconds
    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled
    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
    :param reuse_grad_buffers: if True, use the model's .grad buffers for gradient accumulation.
      This is more memory-efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many epochs
    :param delay_grad_averaging: if True, average gradients in background and apply them in a later step
      (requires delay_optimizer_step)
    :param delay_optimizer_step: if True, run the optimizer step in background and apply results in a future step
    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
      (defaults to dht.client_mode)
    :param auxiliary: if True, this peer does not accumulate gradients or run optimizer steps; it only assists
      other peers with gradient and state averaging
    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
    :param shutdown_timeout: wait for at most this many seconds when shutting down background components
    :param verbose: if True, report internal events such as accumulating gradients and running background tasks

    Internally, hivemind.Optimizer consists of 4 components:

    - DHT, a decentralized key-value storage used for coordination across the swarm
    - GradientAverager that is responsible for aggregating gradients with peers for global steps (can be disabled)
    - TrainingStateAverager that holds parameters and optimizer/scheduler statistics, keeping them weakly synchronized
      by averaging with peers. It can also download these variables from other peers if your peer is out of sync.
    - ProgressTracker that uses the DHT to track global training progress: the number of steps or samples accumulated
    """

    def __init__(
        self,
        *,
        dht: DHT,
        prefix: str,
        target_batch_size: int,
        batch_size_per_step: Optional[int] = None,
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        params: Optional[Union[Parameters, ParamGroups]] = None,
        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
        matchmaking_time: Optional[float] = 15.0,
        averaging_timeout: Optional[float] = 300.0,
        load_state_timeout: float = 600.0,
        average_state_every: int = 1,
        reuse_grad_buffers: bool = False,
        delay_grad_averaging: bool = False,
        delay_optimizer_step: Optional[bool] = None,
        client_mode: Optional[bool] = None,
        auxiliary: bool = False,
        averager_opts: Optional[dict] = None,
        tracker_opts: Optional[dict] = None,
        shutdown_timeout: float = 5,
        verbose: bool = False,
    ):
        client_mode = client_mode if client_mode is not None else dht.client_mode
        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"

        self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
        self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
        self.shutdown_timeout = shutdown_timeout
        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
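
        # handles for the pre-scheduled (current) and previous gradient averaging rounds, managed in .step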
        self.scheduled_round: Optional[StepControl] = None
        self.previous_round: Optional[StepControl] = None

        self.state_averager = self._make_state_averager(
            optimizer=optimizer, params=params, scheduler=scheduler, **averager_opts or {}
        )
        self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})

        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
        self._schema_hash = self._compute_schema_hash()
        self._parent_pid = os.getpid()

        self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.

    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
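        # offload_optimizer=True: the wrapped optimizer works on an averager-managed copy of the parameters;
        # custom_gradients=True: its .grad buffers are filled externally (by the gradient averager wired up below)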
        return TrainingStateAverager(
            dht=self.dht,
            prefix=f"{self.prefix}_state_averager",
            allreduce_timeout=self.averaging_timeout,
            shutdown_timeout=self.shutdown_timeout,
            status_loglevel=self.status_loglevel,
            client_mode=self.client_mode,
            auxiliary=self.auxiliary,
            offload_optimizer=True,
            custom_gradients=True,
            start=True,
            **kwargs,
        )

    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
        assert hasattr(self, "state_averager"), "must initialize state averager first"
        grad_averager = GradientAverager(
            dht=self.dht,
            prefix=f"{self.prefix}_grad_averager",
            parameters=self.state_averager.main_parameters,
            allreduce_timeout=self.averaging_timeout,
            shutdown_timeout=self.shutdown_timeout,
            client_mode=self.client_mode,
            auxiliary=self.auxiliary,
            start=True,
            **kwargs,
        )
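
        # point the offloaded optimizer's .grad attributes at the averager's buffers, so that optimizer steps
        # use gradients that have just been averaged with peers (or local accumulators as a fallback)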
        optimized_param_groups = self.state_averager.optimizer.param_groups
        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
        with grad_averager.get_tensors() as averaged_gradients:
            assert len(averaged_gradients) == len(optimized_parameters)
            for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
                opt_param.grad = averaged_grad
        return grad_averager

    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
        return ProgressTracker(
            dht=self.dht,
            prefix=self.prefix,
            target_batch_size=target_batch_size,
            client_mode=self.client_mode,
            status_loglevel=self.status_loglevel,
            start=True,
            **kwargs,
        )

    def _compute_schema_hash(self) -> int:
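        # fingerprint parameter shapes and gradient buffer identities; .step re-checks this hash to detect
        # parameters or gradient buffers being swapped out mid-epoch (see the assert before the global step)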
        optimized_param_groups = self.state_averager.optimizer.param_groups
        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
        grad_ids = tuple(id(param.grad) for param in optimized_parameters)
        return hash((grad_ids, param_shapes))

    def is_alive(self) -> bool:
        return self.state_averager.is_alive()

    @property
    def local_epoch(self) -> int:
        return self.state_averager.local_epoch

    def should_load_state_from_peers(self) -> bool:
        """
        If true, this peer will discard local progress and attempt to download state from peers.
        This method allows a peer to continue training in two cases:
         - the peer is on the same epoch as other collaborators - keep training normally
         - the peer was on the same epoch and accumulated some grads, but some collaborators
           have just transitioned to the next epoch - this peer should also transition.

        :note: The latter case occurs due to the lack of network synchrony: the first peer that
          detects enough samples will transition to the next step and start counting samples anew.
          Some other peers may take time before they check with the DHT and observe that
           - the global epoch is technically one epoch ahead of the current one and
           - the remaining (non-transitioned) peers no longer have target_batch_size between them.
          If this is the case, the peer should transition to the next epoch and does *not* need to reload state.
        """
        if self._should_check_synchronization_on_update and self.tracker.updated_progress_this_epoch.is_set():
            self._should_check_synchronization_on_update = False
            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch

    def step(
        self,
        closure: Optional[Callable[[], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        grad_scaler: Optional[GradScaler] = None,
    ):
        """
        Accumulate gradients for batch_size additional samples and, once the swarm collectively reaches
        target_batch_size, perform a global optimizer step.

        :param closure: a closure that re-evaluates the model and returns the loss
        :param batch_size: optional override for batch_size_per_step from init
        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
        """
        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.auxiliary and self.should_load_state_from_peers():
            logger.log(self.status_loglevel, "Peer is out of sync.")
            self.load_state_from_peers()
            return loss

        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
            logger.log(self.status_loglevel, "Encountered non-finite values in fp16 grads, resetting local gradients")
            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
            self.grad_averager.reset_accumulated_grads_()
            return loss

        if not self.auxiliary:
            self.grad_averager.accumulate_grads_(batch_size)
            self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
            self.state_averager.step(apply_delayed_updates=True)
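
        # if the collaboration is expected to reach target_batch_size within ~matchmaking_time, pre-schedule the
        # next gradient averaging round so that matchmaking overlaps with the remaining gradient accumulation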
        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
            if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
                if self.delay_grad_averaging:
                    # wait for previous averaging to finish before starting a new one
                    self.state_averager.step(wait_for_delayed_update=True)

                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
                scheduled_time = self.tracker.estimated_next_update_time
                if self.client_mode:
                    scheduled_time = get_dht_time() + self.averaging_timeout
                self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)

        if not self.tracker.ready_to_update_epoch:
            return loss
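
        # beyond this point, the collaboration as a whole has accumulated target_batch_size samples:
        # average gradients with peers (unless training alone) and perform the global optimizer step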
        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
        with self.tracker.pause_updates():
            # note: we do not need to replace grads because we explicitly load grads into the optimizer
            logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")

            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.unscale_(self)

            if self.scheduled_round is not None and (self.scheduled_round.triggered or self.scheduled_round.done()):
                logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
                self.scheduled_round = None
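
            # attempt to all-reduce accumulated gradients, but only if at least one other peer is present;
            # if matchmaking or averaging fails, fall back to purely local gradients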
            swarm_not_empty = self.tracker.global_progress.num_peers > 1
            began_averaging_gradients = False

            if swarm_not_empty:
                try:
                    self.scheduled_round = self.grad_averager.step(
                        control=self.scheduled_round, reset_accumulators=True, wait=False
                    )
                    assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
                    began_averaging_gradients = True
                except BaseException as e:
                    logger.exception(e)

            if not began_averaging_gradients and self.scheduled_round is not None and not self.scheduled_round.done():
                logger.log(self.status_loglevel, "Cancelled pre-scheduled averaging round")
                self.scheduled_round.cancel()
                self.scheduled_round = None

            if not self.delay_grad_averaging:
                self._average_gradients_and_load_into_optimizer(self.scheduled_round)
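
            # run the global step itself: an (optionally delayed) optimizer step, epoch increment and, once every
            # average_state_every epochs, averaging of parameters and optimizer state with other peers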
            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
            self.state_averager.step(
                increment_epoch=True,
                optimizer_step=not self.auxiliary,
                delay_optimizer_step=self.delay_optimizer_step,
                averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
                delay_averaging=not self.auxiliary,
                grad_scaler=grad_scaler,
                wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
                if self.delay_grad_averaging
                else None,
                averaging_opts=dict(
                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                )
                if swarm_not_empty and next_epoch % self.average_state_every == 0
                else None,
            )
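
            # local gradient accumulators have been consumed by the global step: reset them and advance the epoch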
            if not self.auxiliary:
                self.grad_averager.reset_accumulated_grads_()
                self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
                self._should_check_synchronization_on_update = True

            if not self.client_mode:
                self.grad_averager.state_sharing_priority = self.local_epoch
                self.state_averager.state_sharing_priority = self.local_epoch

            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
        return loss

    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
        assert maybe_step_control is None or maybe_step_control.triggered
        averaged_gradients = False

        try:
            if maybe_step_control is not None:
                group_info = maybe_step_control.result(self.averaging_timeout)
                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
                averaged_gradients = True
            else:
                logger.log(self.status_loglevel, "Skipped averaging: there are no other peers")
        except BaseException as e:
            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")

        if not averaged_gradients:
            logger.log(self.status_loglevel, "Proceeding with local gradients")
            self.grad_averager.load_accumulators_into_averager_()

        self.grad_averager.notify_used_averaged_gradients()

    def zero_grad(self, set_to_none: bool = False):
        """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
        if self.grad_averager.reuse_grad_buffers:
            raise ValueError(
                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                f"call zero_grad manually. Gradients will be refreshed internally."
            )
        for param in self.grad_averager.parameters:
            if param.grad is None:
                pass
            elif set_to_none:
                param.grad = None
            else:
                param.grad.zero_()

    def load_state_from_peers(self, **kwargs):
        """Attempt to fetch the newest collaboration state from other peers"""
        if self.scheduled_round is not None and not self.scheduled_round.done():
            self.scheduled_round.cancel()

        with self.tracker.pause_updates():
            while True:
                try:
                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                    break
                except KeyboardInterrupt:
                    raise
                except BaseException as e:
                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                    continue

            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
                self.state_averager.local_epoch = self.tracker.global_epoch

            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
            self.grad_averager.reset_accumulated_grads_()

            if not self.client_mode:
                self.grad_averager.state_sharing_priority = self.local_epoch
                self.state_averager.state_sharing_priority = self.local_epoch

    def state_dict(self) -> dict:
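        # store the local epoch alongside the optimizer state so that it is restored by load_state_dict below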
        state_dict = self.state_averager.optimizer.state_dict()
        state_dict["state"]["local_epoch"] = self.local_epoch
        return state_dict

    def load_state_dict(self, state_dict: dict):
        if "local_epoch" in state_dict["state"]:
            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
        return self.state_averager.optimizer.load_state_dict(state_dict)

    @property
    def state(self):
        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)

    @property
    def opt(self) -> TorchOptimizer:
        return self.state_averager.optimizer

    @property
    def param_groups(self) -> ParamGroups:
        next_index = 0
        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
        for param_group in param_groups:
            num_params = len(param_group["params"])
            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
            param_group["params"] = main_params_for_group
            next_index += num_params
        assert next_index == len(self.state_averager.main_parameters)
        return param_groups

    def add_param_group(self, param_group: dict) -> None:
        raise ValueError(
            f"{self.__class__.__name__} does not support calling add_param_group after creation. "
            f"Please provide all parameter groups at init."
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(prefix={self.prefix}, epoch={self.local_epoch})"

    def shutdown(self):
        logger.debug("Sending goodbye to peers...")
        self.tracker.shutdown(self.shutdown_timeout)
        logger.debug("Shutting down averager...")
        self.state_averager.step(wait_for_delayed_update=True)
        self.state_averager.shutdown()
        self.grad_averager.shutdown()
        logger.debug(f"{self.__class__.__name__} is shut down.")

    def __del__(self):
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()