# optimizer.py

  1. from __future__ import annotations
  2. import logging
  3. import os
  4. from typing import Callable, Optional, Union
  5. import torch
  6. from hivemind.averaging.control import StepControl
  7. from hivemind.dht import DHT
  8. from hivemind.optim.experimental.grad_averager import GradientAverager
  9. from hivemind.optim.experimental.progress_tracker import ProgressTracker
  10. from hivemind.optim.experimental.state_averager import (
  11. LRSchedulerBase,
  12. OptimizerFactory,
  13. Parameters,
  14. ParamGroups,
  15. SchedulerFactory,
  16. TorchOptimizer,
  17. TrainingStateAverager,
  18. )
  19. from hivemind.optim.grad_scaler import GradScaler
  20. from hivemind.utils import get_dht_time, get_logger
  21. logger = get_logger(__name__)
  22. class Optimizer(torch.optim.Optimizer):
  23. """
  24. Hivemind Optimizer wraps your regular PyTorch Optimizer for training in a swarm of peers. It can be configured with
  25. synchronous, delayed or asynchronous updates to trade between optimization guarantees and compute utilization.
  26. The Optimizer is meant as a drop-in replacement for your regular PyTorch code:
  27. >>> model = transformers.AutoModel("albert-xxlarge-v2")
  28. >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
  29. >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, prefix="run_42",
  30. >>> target_batch_size=4096, batch_size_per_step=4)
  31. >>> while True:
  32. >>> loss = compute_loss_on_batch(model, batch_size=4)
  33. >>> opt.zero_grad()
  34. >>> loss.backward()
  35. >>> opt.step() # <-- train collaboratively with any peers that use the same prefix (run_42)
  36. However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
  37. - accumulate a minibatch of data towards the (global) target batch size without changing parameters (yet),
  38. - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step,
  39. - if, for any reason, your peer lags behind the rest of the swarm, it will load state from up-to-date peers.
  40. :note: Hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
  41. learning rate schedulers, curriculum and other time-dependent features should use opt.global_step (and not the
  42. number of local forward-backward cycles). This is because any device can join midway through training, when
  43. other peers have already made some progress and changed their learning rate accordingly.
  44. :param dht: a running hivemind.DHT instance connected to other peers
  45. :param prefix: a unique name of this experiment, used as a common prefix for all DHT keys
  46. :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
  47. :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
  48. :param optimizer: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
  49. :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
  50. :param scheduler: if specified, use this scheduler to update optimizer learning rate
  51. :note: If you are using ColloptaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
  52. explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
  53. :param matchmaking_time: when looking for group, wait for peers to join for up to this many secodns
  54. :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
  55. :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
  56. :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
  57. This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
  58. :param epoch_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
  59. :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
  60. :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
  61. :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
  62. :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
  63. :param verbose: if True, report internal events such as accumilating gradients and running background tasks
  64. Internally, hivemind.Optimizer consists of 4 components:
  65. - DHT, a decentralized key-value storage used for coordination across the swarm
  66. - GradientAverager that is responsible for aggregating gradients with peers for global steps (can be disabled)
  67. - TrainingStateAverager holds parameters and optimizer/scheduler statistics, keeping them weakly synchronized
  68. by averaging with peers. It can also download these variable from other peers if your peer is out of sync.
  69. - ProgressTracker that uses DHT to track the global training progress: the number of steps or samples accumulated
  70. """
  71. def __init__(
  72. self,
  73. *,
  74. dht: DHT,
  75. prefix: str,
  76. target_batch_size: int,
  77. batch_size_per_step: Optional[int] = None,
  78. optimizer: Union[TorchOptimizer, OptimizerFactory],
  79. params: Optional[Union[Parameters, ParamGroups]] = None,
  80. scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
  81. matchmaking_time: Optional[float] = 15.0,
  82. averaging_timeout: Optional[float] = 300.0,
  83. load_state_timeout: float = 600.0,
  84. reuse_grad_buffers: bool = False,
  85. epoch_tolerance: int = 1,
  86. delay_optimizer_step: bool = False,
  87. client_mode: bool = None,
  88. averager_opts: Optional[dict] = None,
  89. tracker_opts: Optional[dict] = None,
  90. verbose: bool = False,
  91. ):
  92. self.dht, self.prefix, self.epoch_tolerance = dht, prefix, epoch_tolerance
  93. self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
  94. self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
  95. self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
  96. self.client_mode = client_mode if client_mode is not None else self.dht.client_mode
  97. self.status_loglevel = logging.INFO if verbose else logging.DEBUG
  98. self.scheduled_round: Optional[StepControl] = None
  99. self.state_averager = self._make_state_averager(
  100. optimizer=optimizer, params=params, scheduler=scheduler, **averager_opts or {}
  101. )
  102. self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
  103. self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
  104. self._schema_hash = self._compute_schema_hash()
  105. self._parent_pid = os.getpid()
  106. self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
  107. # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
  108. # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
  109. def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
  110. return TrainingStateAverager(
  111. dht=self.dht,
  112. prefix=f"{self.prefix}_state_averager",
  113. allreduce_timeout=self.averaging_timeout,
  114. status_loglevel=self.status_loglevel,
  115. client_mode=self.client_mode,
  116. offload_optimizer=True,
  117. custom_gradients=True,
  118. start=True,
  119. **kwargs,
  120. )
  121. def _make_gradient_averager(self, **kwargs) -> GradientAverager:
  122. assert hasattr(self, "state_averager"), "must initialize state averager first"
  123. grad_averager = GradientAverager(
  124. dht=self.dht,
  125. prefix=f"{self.prefix}_grad_averager",
  126. parameters=self.state_averager.main_parameters,
  127. allreduce_timeout=self.averaging_timeout,
  128. client_mode=self.client_mode,
  129. start=True,
  130. **kwargs,
  131. )
  132. optimized_param_groups = self.state_averager.optimizer.param_groups
  133. optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
  134. with grad_averager.get_tensors() as averaged_gradients:
  135. assert len(averaged_gradients) == len(optimized_parameters)
  136. for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
  137. opt_param.grad = averaged_grad
  138. return grad_averager
  139. def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
  140. return ProgressTracker(
  141. dht=self.dht,
  142. prefix=self.prefix,
  143. target_batch_size=target_batch_size,
  144. client_mode=self.client_mode,
  145. status_loglevel=self.status_loglevel,
  146. start=True,
  147. **kwargs,
  148. )
  149. def _compute_schema_hash(self) -> int:
  150. optimized_param_groups = self.state_averager.optimizer.param_groups
  151. optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
  152. param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
  153. grad_ids = tuple(id(param.grad) for param in optimized_parameters)
  154. return hash((grad_ids, param_shapes))
  155. def is_alive(self) -> bool:
  156. return self.state_averager.is_alive()
  157. @property
  158. def local_epoch(self) -> int:
  159. return self.state_averager.local_epoch
  160. @property
  161. def should_load_state_from_peers(self) -> bool:
  162. """If true, peer will discard local progress and attempt to download state from peers."""
  163. return self.local_epoch < self.tracker.global_epoch - self.epoch_tolerance
  164. def step(
  165. self,
  166. closure: Optional[Callable[[], torch.Tensor]] = None,
  167. batch_size: Optional[int] = None,
  168. grad_scaler: Optional[GradScaler] = None,
  169. **kwargs,
  170. ):
  171. """
  172. Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
  173. :param closure: A closure that reevaluates the model and returns the loss
  174. :param batch_size: optional override for batch_size_per_step from init
  175. :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
  176. :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
  177. """
  178. if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
  179. raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler).")
  180. if self.batch_size_per_step is None and batch_size is None:
  181. raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
  182. batch_size = batch_size if batch_size is not None else self.batch_size_per_step
  183. loss = None
  184. if closure is not None:
  185. with torch.enable_grad():
  186. loss = closure()
  187. if self.should_load_state_from_peers:
  188. self.load_state_from_peers()
  189. return loss
  190. if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
  191. logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
  192. self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
  193. self.grad_averager.reset_accumulated_grads_()
  194. return loss
  195. self.grad_averager.accumulate_grads_(batch_size)
  196. self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
  197. self.state_averager.step(apply_delayed_updates=True)
  198. if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
  199. if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
  200. eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
  201. eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
  202. logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
  203. scheduled_time = self.tracker.estimated_next_update_time
  204. if self.client_mode:
  205. scheduled_time = get_dht_time() + self.averaging_timeout
  206. self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
  207. if not self.tracker.ready_to_update_epoch:
  208. return loss
  209. with self.tracker.pause_updates():
  210. logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
  211. # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
  212. if grad_scaler is not None:
  213. with grad_scaler.running_global_step():
  214. assert grad_scaler.unscale_(self)
  215. assert grad_scaler.update()
  216. if self.scheduled_round is not None and self.scheduled_round.triggered or self.scheduled_round.done():
  217. logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
  218. self.scheduled_round = None
  219. need_averaging = self.tracker.global_progress.num_peers > 1
  220. if need_averaging:
  221. try:
  222. group_info = self.grad_averager.step(
  223. control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
  224. )
  225. logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
  226. except BaseException as e:
  227. logger.log(self.status_loglevel, f"Averaging failed with {repr(e)}")
  228. else:
  229. if self.scheduled_round is not None:
  230. self.scheduled_round.cancel()
  231. logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
  232. assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
  233. with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
  234. # note: we do not need to replace because the offloaded optimizer is already using averaged grads
  235. self.state_averager.step(
  236. increment_epoch=True,
  237. optimizer_step=True,
  238. delay_optimizer_step=self.delay_optimizer_step,
  239. grad_scaler=grad_scaler,
  240. averaging_round=need_averaging,
  241. delay_averaging=True,
  242. averaging_opts=dict(
  243. scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
  244. )
  245. if need_averaging
  246. else None,
  247. )
  248. self.grad_averager.reset_accumulated_grads_()
  249. self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
  250. logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
  251. return loss
  252. def step_aux(self, **kwargs):
  253. """
  254. Find and assist other peers in averaging without sending local gradients.
  255. :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
  256. """
  257. raise NotImplementedError("Auxiliary step for hivemind.Optimizer is not implemented yet.")
  258. def zero_grad(self, set_to_none: bool = False):
  259. """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
  260. if self.grad_averager.reuse_grad_buffers:
  261. raise ValueError(
  262. f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
  263. f"call zero_grad manually. Gradients will be refreshed internally."
  264. )
  265. for param in self.grad_averager.parameters:
  266. if param.grad is None:
  267. pass
  268. elif set_to_none:
  269. param.grad = None
  270. else:
  271. param.grad.zero_()
  272. def load_state_from_peers(self, **kwargs):
  273. """Attempt to fetch the newest collaboration state from other peers"""
  274. if self.scheduled_round is not None and not self.scheduled_round.done():
  275. self.scheduled_round.cancel()
  276. with self.tracker.pause_updates():
  277. while True:
  278. try:
  279. self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
  280. break
  281. except KeyboardInterrupt:
  282. raise
  283. except BaseException as e:
  284. logger.exception(f"Failed to load state from peers: {e}, retrying ...")
  285. continue
  286. if self.tracker.global_epoch - self.epoch_tolerance <= self.local_epoch < self.tracker.global_epoch:
  287. logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
  288. self.state_averager.local_epoch = self.tracker.global_epoch
  289. self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
  290. self.grad_averager.reset_accumulated_grads_()
  291. def state_dict(self) -> dict:
  292. state_dict = self.state_averager.optimizer.state_dict()
  293. state_dict["state"]["local_epoch"] = self.local_epoch
  294. return state_dict
  295. def load_state_dict(self, state_dict: dict):
  296. if "local_epoch" in state_dict["state"]:
  297. self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
  298. return self.state_averager.optimizer.load_state_dict(state_dict)
  299. @property
  300. def state(self):
  301. return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
  302. @property
  303. def opt(self) -> TorchOptimizer:
  304. return self.state_averager.optimizer
  305. @property
  306. def param_groups(self) -> ParamGroups:
  307. next_index = 0
  308. param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
  309. for param_group in param_groups:
  310. num_params = len(param_group["params"])
  311. main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
  312. param_group["params"] = main_params_for_group
  313. next_index += num_params
  314. assert next_index == len(self.state_averager.main_parameters)
  315. return param_groups
  316. def add_param_group(self, param_group: dict) -> None:
  317. raise ValueError(
  318. f"{self.__class__.__name__} does not support calling add_param_group after creation."
  319. f"Please provide all parameter groups at init."
  320. )
  321. def __repr__(self):
  322. return f"{self.__class__.__name__}(prefix={self.prefix}, epoch={self.local_epoch})"
  323. def shutdown(self):
  324. logger.debug("Sending goodbye to peers...")
  325. self.tracker.shutdown()
  326. logger.debug("Shutting down averager...")
  327. self.state_averager.step(wait_for_delayed_update=True)
  328. self.state_averager.shutdown()
  329. self.grad_averager.shutdown()
  330. logger.debug(f"{self.__class__.__name__} is shut down.")
  331. def __del__(self):
  332. if self.is_alive() and self._parent_pid == os.getpid():
  333. self.shutdown()