# decentralized_optimizers.py

from dataclasses import dataclass
import time
from threading import Thread, Lock, Event
from typing import Optional, Sequence, Tuple

import torch

from hivemind.dht import DHT
from hivemind.client import TrainingAverager
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration

logger = get_logger(__name__)


@dataclass(frozen=False)
class DecentralizedState:
    max_epoch: int = 0
    total_steps: int = 0


class DecentralizedOptimizer(DecentralizedOptimizerBase):
    """
    A simple optimizer that trains a shared model by averaging with peers in a variety of ways. Supports
    parameter/gradient averaging and syncing adaptive learning rates or any other internal optimizer statistics.

    :param opt: a pytorch optimizer configured to update model parameters
    :param dht: a running hivemind DHT daemon connected to other peers
    :param prefix: all DHT keys that point to optimization metadata will have this prefix
    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
    :param average_parameters: whether to average model parameters
    :param average_gradients: whether to average gradients
    :param max_allowed_epoch_difference: if the global max_epoch is ahead of local_epoch by more than this value,
      download the latest state from another peer
    :param total_steps_in_epoch: how many total steps (summed over peers) are required to increase local_epoch by one
    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
    :param scheduler_cls: a callable that takes the optimizer and returns a learning rate scheduler
    :param averaging_steps_period: perform averaging after this many optimizer steps
    :param averaging_time_period: if specified, the optimizer will attempt to average weights at regular intervals of
      this many seconds (an averaging step will only occur if the optimizer ran `averaging_steps_period` steps in
      that interval)
    :param report_progress_expiration: progress records stored in the DHT expire after this many seconds
    :param timeout: if a DecentralizedAverager step is unable to form a group in this many seconds, cancel the step
    :param verbose: whether to log information about each averaging round
    :param kwargs: additional parameters passed to TrainingAverager
    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify the
      necessary field names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
    """

    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
                 total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
                 scheduler_cls=None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                 report_progress_expiration: int = 30, timeout: Optional[float] = None,
                 verbose: bool = False, **kwargs):
        super().__init__(opt, dht)
        assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
        self.local_step, self.averaging_step_period = 0, averaging_steps_period

        self.dht = dht
        self.averager = TrainingAverager(opt, average_parameters=average_parameters,
                                         average_gradients=average_gradients,
                                         average_opt_statistics=average_opt_statistics,
                                         dht=dht, start=True, prefix=prefix,
                                         target_group_size=target_group_size, **kwargs)
        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
        self.scheduler = scheduler_cls(opt) if scheduler_cls else None
        self.local_epoch = 0
        self.report_progress_expiration = report_progress_expiration
        self.max_allowed_epoch_difference = max_allowed_epoch_difference
        self.total_steps_in_epoch = total_steps_in_epoch
        self.report_progress_key = f"{prefix}.progress"
        self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
        self.lock_scheduler_params = Lock()
        self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)

        # fetch the current collaboration state once before the background threads take over
        self.fetch_decentralized_state_event.set()
        self._fetch_decentralized_state(initial=True)

        self.background_averaging_thread = Thread(
            name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose))
        self.background_averaging_thread.start()

        self.background_report_progress = Thread(
            name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
        self.background_report_progress.start()

        self.background_fetch_decentralized_state = Thread(
            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
        self.background_fetch_decentralized_state.start()

    def step(self, *args, **kwargs):
        if not self.is_synchronized:
            logger.warning("Peer is out of sync.")
            self.load_states_from_peers(**kwargs)
            return

        with self.lock_scheduler_params:
            if self.local_epoch < self.decentralized_state.max_epoch:
                self.local_step = 0
                self.local_epoch = self.decentralized_state.max_epoch
            if self.decentralized_state.total_steps >= self.total_steps_in_epoch:
                self.local_epoch += 1
                self.local_step = 0
            if self.scheduler:
                while self.local_epoch > self.scheduler._step_count:
                    self.scheduler.step()

        with self.lock_parameters:
            step_result = self.opt.step(*args, **kwargs)

        self.local_step += 1
        if self.local_step % self.averaging_step_period == 0:
            self.update_event.set()
        self.report_progress_event.set()
        self.fetch_decentralized_state_event.set()
        return step_result

    def zero_grad(self, *args, **kwargs):
        return self.opt.zero_grad(*args, **kwargs)

    def __del__(self):
        self.stop_event.set()
        self.update_event.set()

    def shutdown(self):
        self.stop_event.set()
        # wake up the background threads so that they notice stop_event and exit
        self.update_event.set()
        self.report_progress_event.set()
        self.fetch_decentralized_state_event.set()
        self.averager.shutdown()

    def load_states_from_peers(self, **kwargs):
        logger.info("Trying to restore state from peers.")
        with self.lock_parameters, self.lock_scheduler_params:
            self.zero_grad()
            self.averager.load_state_from_peers(**kwargs)
            self.local_step = 0
            self.local_epoch = self.decentralized_state.max_epoch

    @staticmethod
    @torch.no_grad()
    def _average_parameters_in_background(
            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: TrainingAverager,
            averaging_period: float, verbose: bool, **kwargs):
        """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
        while not stop_event.is_set():
            update_event.wait()
            update_event.clear()
            if stop_event.is_set():
                break

            if averaging_period:
                current_time = get_dht_time()
                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
                time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
                time.sleep(time_to_nearest_interval)

            if verbose:
                logger.info("Starting a new averaging round with current parameters.")
            try:
                group_info = averager.step(lock_parameters, **kwargs)
                if verbose:
                    if group_info is not None:
                        logger.info(f"Finished averaging round with {len(group_info)} peers.")
                    else:
                        logger.warning("Averaging round failed: could not find a group.")
            except Exception as e:
                logger.error(f"Averaging round failed: caught {e}.")

    @property
    def is_synchronized(self) -> bool:
        return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch

    @torch.no_grad()
    def _report_progress(self):
        """ Periodically publish this peer's local step, timestamp and epoch to the DHT """
        while not self.stop_event.is_set():
            self.report_progress_event.wait()
            self.report_progress_event.clear()
            if self.stop_event.is_set():
                break

            current_time = get_dht_time()
            with self.lock_scheduler_params:
                local_state_info = [self.local_step, current_time, self.local_epoch]
            self.dht.store(key=self.report_progress_key, subkey=self.averager.endpoint, value=local_state_info,
                           expiration_time=current_time + self.report_progress_expiration, return_future=False)

    @torch.no_grad()
    def _fetch_decentralized_state(self, initial: bool = False):
        """ Read collaboration state reported by peers """
        while not self.stop_event.is_set():
            self.fetch_decentralized_state_event.wait()
            self.fetch_decentralized_state_event.clear()
            if self.stop_event.is_set():
                break

            response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
            if not isinstance(response, dict) or len(response) == 0:
                logger.info(f"Found no active peers: {response}")
                with self.lock_scheduler_params:
                    self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch,
                                                                  total_steps=self.local_step)
                if initial:
                    break
                continue

            valid_peer_states = [peer_state.value for peer_state in response.values()
                                 if isinstance(peer_state, ValueWithExpiration)]
            num_peers = len(valid_peer_states)
            with self.lock_scheduler_params:
                # the global epoch is the maximum epoch reported by any peer (including this one)
                global_epoch = self.local_epoch
                for step, timestamp, epoch in valid_peer_states:
                    global_epoch = max(global_epoch, epoch)
                # total_steps counts only the steps made by peers that have reached the global epoch
                total_steps = 0
                for step, timestamp, epoch in valid_peer_states:
                    if epoch == global_epoch:
                        total_steps += step
                self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
            if initial:
                break
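

# A minimal usage sketch (illustrative, not part of the library): wrap a plain torch optimizer in
# DecentralizedOptimizer and run one training step. The toy model, the "demo_run" prefix and all
# hyperparameter values are placeholders; a running `dht` instance is assumed to be created elsewhere.
def _example_decentralized_optimizer_usage(dht: DHT) -> None:
    model = torch.nn.Linear(16, 4)
    base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    opt = DecentralizedOptimizer(
        base_opt, dht, prefix="demo_run", target_group_size=16,
        average_parameters=True, average_gradients=False,
        max_allowed_epoch_difference=1, total_steps_in_epoch=100,
        averaging_steps_period=10, verbose=True)
    # used like any torch optimizer: backward, then step/zero_grad; averaging happens in the background
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    opt.shutdown()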


class DecentralizedSGD(DecentralizedOptimizer):
    """
    Decentralized Stochastic Gradient Descent, as in Lian et al. (2017) [1], based on Moshpit All-Reduce [2].

    :param dht: a running hivemind DHT daemon connected to other peers
    :param prefix: all DHT keys that point to optimization metadata will have this prefix
    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
    :param kwargs: additional parameters passed to DecentralizedOptimizer

    - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic
      Gradient Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
    - [2] Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices -
      https://arxiv.org/abs/2103.03239
    """

    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int,
                 momentum: float = 0, dampening: float = 0, weight_decay: float = 0, nesterov: bool = False, **kwargs):
        opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
                         average_gradients=False, **kwargs)
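

# A minimal usage sketch (illustrative): DecentralizedSGD builds the underlying torch.optim.SGD itself,
# so it is constructed directly from model parameters. The prefix and hyperparameter values are placeholders;
# max_allowed_epoch_difference and total_steps_in_epoch are forwarded to DecentralizedOptimizer via **kwargs.
def _example_decentralized_sgd_usage(model: torch.nn.Module, dht: DHT) -> DecentralizedSGD:
    return DecentralizedSGD(
        model.parameters(), lr=0.03, dht=dht, prefix="sgd_run", target_group_size=32, momentum=0.9,
        max_allowed_epoch_difference=1, total_steps_in_epoch=500)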


class DecentralizedAdam(DecentralizedOptimizer):
    """
    Decentralized Adam/AMSGrad as proposed in [1], [2].

    :param dht: a running hivemind DHT daemon connected to other peers
    :param prefix: all DHT keys that point to optimization metadata will have this prefix
    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
    :param averaging_steps_period: performs averaging after this many optimizer steps
    :param kwargs: additional parameters passed to DecentralizedOptimizer

    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
    - [2] Toward Communication Efficient Adaptive Gradient Method -
      https://dl.acm.org/doi/abs/10.1145/3412815.3416891
    """

    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int,
                 averaging_steps_period: int, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                 weight_decay: float = 0, amsgrad: bool = False, **kwargs):
        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
                         average_gradients=False, average_opt_statistics=opt_statistics,
                         averaging_steps_period=averaging_steps_period, **kwargs)
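

# A minimal usage sketch (illustrative): DecentralizedAdam also averages the second-moment statistics
# ("exp_avg_sq", or "max_exp_avg_sq" when amsgrad=True) along with parameters. Values are placeholders;
# the epoch-related arguments are forwarded to DecentralizedOptimizer via **kwargs.
def _example_decentralized_adam_usage(model: torch.nn.Module, dht: DHT) -> DecentralizedAdam:
    return DecentralizedAdam(
        model.parameters(), lr=1e-3, dht=dht, prefix="adam_run", target_group_size=16,
        averaging_steps_period=5, max_allowed_epoch_difference=1, total_steps_in_epoch=1000)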