simple.py

import time
from threading import Thread, Lock, Event
from typing import Optional, Sequence, Tuple

import torch

from hivemind.dht import DHT
from hivemind.averaging import TrainingAverager
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.utils import get_logger, get_dht_time

logger = get_logger(__name__)

class DecentralizedOptimizer(DecentralizedOptimizerBase):
    """
    A simple optimizer that trains a shared model by averaging with peers in a variety of ways. Supports
    parameter/gradient averaging and syncing adaptive learning rates or any other internal statistics of the optimizer.

    :param opt: a pytorch optimizer configured to update model parameters.
    :param dht: a running hivemind DHT daemon connected to other peers
    :param average_parameters: whether to average model parameters
    :param average_gradients: whether to average gradients
    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
    :param averaging_steps_period: performs averaging after this many optimizer steps
    :param averaging_time_period: if specified, the optimizer will attempt to average weights at regular intervals of
      this many seconds (an averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that
      interval)
    :param prefix: all DHT keys that point to optimization metadata will have this prefix
    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
    :param timeout: if a DecentralizedAverager step is unable to form a group in this many seconds, cancel the step
    :param kwargs: additional parameters passed to TrainingAverager
    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
      the necessary field names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
    """

    def __init__(
        self,
        opt: torch.optim.Optimizer,
        dht: DHT,
        *,
        prefix: str,
        target_group_size: int,
        average_parameters: bool,
        average_gradients: bool,
        average_opt_statistics: Sequence[str] = (),
        averaging_steps_period: int = 1,
        averaging_time_period: float = 0,
        timeout: Optional[float] = None,
        verbose: bool = False,
        **kwargs,
    ):
        super().__init__(opt, dht)
        assert (
            averaging_steps_period > 0 and averaging_time_period >= 0
        ), "Averaging steps period must be positive and averaging time period must be non-negative."
        self.local_step, self.averaging_step_period = 0, averaging_steps_period

        self.averager = TrainingAverager(
            opt,
            average_parameters=average_parameters,
            average_gradients=average_gradients,
            average_opt_statistics=average_opt_statistics,
            dht=dht,
            start=True,
            prefix=prefix,
            target_group_size=target_group_size,
            **kwargs,
        )
        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()

        self.background_averaging_thread = Thread(
            name=f"{self.__class__.__name__}",
            daemon=True,
            target=self._average_parameters_in_background,
            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose),
        )
        self.background_averaging_thread.start()

    def step(self, *args, **kwargs):
        with self.lock_parameters:
            loss = self.opt.step(*args, **kwargs)
        self.local_step += 1
        if self.local_step % self.averaging_step_period == 0:
            self.update_event.set()
        return loss

    def zero_grad(self, *args, **kwargs):
        return self.opt.zero_grad(*args, **kwargs)

    def __del__(self):
        self.stop_event.set()
        self.update_event.set()

    def shutdown(self):
        self.stop_event.set()
        self.update_event.set()
        self.averager.shutdown()

    @staticmethod
    @torch.no_grad()
    def _average_parameters_in_background(
        lock_parameters: Lock,
        update_event: Event,
        stop_event: Event,
        averager: TrainingAverager,
        averaging_period: float,
        verbose: bool,
        **kwargs,
    ):
        """Iteratively find groups of peers, average parameters with these peers and update local model parameters."""
        while not stop_event.is_set():
            update_event.wait()
            update_event.clear()
            if stop_event.is_set():
                break

            if averaging_period:
                current_time = get_dht_time()
                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
                time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
                time.sleep(time_to_nearest_interval)

            if verbose:
                logger.info("Starting a new averaging round with current parameters.")
            try:
                group_info = averager.step(lock_parameters, **kwargs)
                if verbose:
                    if group_info is not None:
                        logger.info(f"Finished averaging round with {len(group_info)} peers.")
                    else:
                        logger.warning("Averaging round failed: could not find group.")
            except Exception as e:
                logger.error(f"Averaging round failed: caught {e}.")
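

# Illustrative usage sketch (not part of the original module): wraps a plain Adam optimizer in
# DecentralizedOptimizer so that parameters and Adam's second-moment statistics are averaged with peers.
# The tiny model, the random data, the prefix "my_run" and the group size are assumptions made for the example.
def _example_decentralized_optimizer():
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))
    dht = DHT(start=True)  # first peer of a fresh DHT; other workers would join via `initial_peers`
    base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt = DecentralizedOptimizer(
        base_opt,
        dht,
        prefix="my_run",
        target_group_size=16,
        average_parameters=True,
        average_gradients=False,
        average_opt_statistics=("exp_avg_sq",),  # keep Adam's adaptive denominators in sync across peers
        averaging_steps_period=10,
    )
    for _ in range(100):
        loss = model(torch.randn(32, 16)).pow(2).mean()
        loss.backward()
        opt.step()  # every 10th step signals the background thread to run an averaging round
        opt.zero_grad()
    opt.shutdown()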


class DecentralizedSGD(DecentralizedOptimizer):
    """
    Decentralized Stochastic Gradient Descent, as in Lian et al. (2017) [1], based on Moshpit All-Reduce [2].

    :param dht: a running hivemind DHT daemon connected to other peers
    :param prefix: all DHT keys that point to optimization metadata will have this prefix
    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
    :param kwargs: additional parameters passed to DecentralizedOptimizer

    - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
      Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
    - [2] Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices -
      https://arxiv.org/abs/2103.03239
    """

    def __init__(
        self,
        params,
        lr: float,
        *,
        dht: DHT,
        prefix: str,
        target_group_size: int,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        **kwargs,
    ):
        opt = torch.optim.SGD(
            params, lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov
        )
        super().__init__(
            opt,
            dht,
            prefix=prefix,
            target_group_size=target_group_size,
            average_parameters=True,
            average_gradients=False,
            **kwargs,
        )
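

# Illustrative usage sketch (not part of the original module): local SGD with periodic parameter averaging.
# The model, learning rate, prefix "sgd_run", group size and averaging period are example assumptions.
def _example_decentralized_sgd():
    import torch.nn as nn

    model = nn.Linear(16, 2)
    dht = DHT(start=True)  # first peer; other workers would join via `initial_peers`
    opt = DecentralizedSGD(
        model.parameters(),
        lr=0.1,
        dht=dht,
        prefix="sgd_run",
        target_group_size=16,
        momentum=0.9,
        averaging_steps_period=5,  # attempt to average parameters every 5 local steps
    )
    for _ in range(50):
        loss = model(torch.randn(32, 16)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
    opt.shutdown()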


class DecentralizedAdam(DecentralizedOptimizer):
    """
    Decentralized Adam/AMSGrad as proposed in [1], [2].

    :param dht: a running hivemind DHT daemon connected to other peers
    :param prefix: all DHT keys that point to optimization metadata will have this prefix
    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
    :param averaging_steps_period: performs averaging after this many optimizer steps
    :param kwargs: additional parameters passed to DecentralizedOptimizer

    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
    - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
    """

    def __init__(
        self,
        params,
        lr: float,
        *,
        dht: DHT,
        prefix: str,
        target_group_size: int,
        averaging_steps_period: int,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        **kwargs,
    ):
        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)

        super().__init__(
            opt,
            dht,
            prefix=prefix,
            target_group_size=target_group_size,
            average_parameters=True,
            average_gradients=False,
            average_opt_statistics=opt_statistics,
            averaging_steps_period=averaging_steps_period,
            **kwargs,
        )
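

# Illustrative usage sketch (not part of the original module): DecentralizedAdam additionally keeps the
# "exp_avg_sq" (or "max_exp_avg_sq" with amsgrad) statistics synchronized across peers, as noted above.
# The model, prefix "adam_run", group size and averaging period are example assumptions.
def _example_decentralized_adam():
    import torch.nn as nn

    model = nn.Linear(16, 2)
    dht = DHT(start=True)  # first peer; other workers would join via `initial_peers`
    opt = DecentralizedAdam(
        model.parameters(),
        lr=1e-3,
        dht=dht,
        prefix="adam_run",
        target_group_size=16,
        averaging_steps_period=10,
        amsgrad=True,  # averages "max_exp_avg_sq" instead of "exp_avg_sq"
    )
    for _ in range(50):
        loss = model(torch.randn(32, 16)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
    opt.shutdown()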