grad_averager.py
import contextlib
from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union

import torch

import hivemind
from hivemind.averaging import DecentralizedAverager
from hivemind.averaging.control import StepControl
from hivemind.utils import DHTExpiration, get_dht_time, get_logger

logger = get_logger(__name__)


class GradientAverager(DecentralizedAverager):
    """
    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
    GradientAverager is meant to be used within hivemind.Optimizer, but it can also be used standalone (see the example below).

    GradientAverager manages three sets of buffers:

    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
        These tensors are typically stored on device and updated by torch autograd.
    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
        - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
          which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory.

    :param parameters: pytorch parameters for which to aggregate gradients
    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
    :param reuse_grad_buffers: if True, use the model's .grad buffers for accumulating gradients over multiple steps.
      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip gradients manually
    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
    :param client_mode: if False, this averager will accept incoming requests from other peers.
      If True, the averager will only join existing groups where at least one peer has client_mode=False.
      By default, this flag is copied from DHTNode inside the ``dht`` instance.
    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
    :param averaged_grads: if provided, it will be used as the set of averageable gradient buffers
    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters

    Example:

    >>> model = SuchModelMuchLayers()
    >>> opt = torch.optim.Adam(model.parameters())
    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
    >>> next_step_time = hivemind.get_dht_time() + 60  # run a global step every 60 seconds
    >>> next_step_control = None
    >>> while True:
    >>>     # accumulate as many gradients as you can before next_step_time
    >>>     loss = compute_loss(model, batch_size=32)
    >>>     loss.backward()
    >>>     grad_averager.accumulate_grads_(batch_size=32)
    >>>     # [optional] next step in 5 seconds, start looking for peers in advance
    >>>     if next_step_time - hivemind.get_dht_time() <= 5:
    >>>         next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
    >>>     # aggregate gradients and perform optimizer step
    >>>     if hivemind.get_dht_time() >= next_step_time:
    >>>         grad_averager.step(control=next_step_control)
    >>>         with grad_averager.use_averaged_gradients():  # this will fill param.grad with aggregated gradients
    >>>             opt.step()  # update model parameters using averaged gradients
    >>>         grad_averager.reset_accumulated_grads_()  # prepare for next step
    >>>         next_step_time = hivemind.get_dht_time() + 60
    >>>         next_step_control = None
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        *,
        dht: hivemind.DHT,
        prefix: str,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: Optional[bool] = None,
        warn: bool = True,
        averaged_grads: Sequence[torch.Tensor] = (),
        **kwargs,
    ):
        if reuse_grad_buffers and accumulate_grads_on is not None:
            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
        client_mode = client_mode if client_mode is not None else dht.client_mode
        self.parameters = tuple(parameters)
        self.reuse_grad_buffers = reuse_grad_buffers
        self.warn = warn
        self.local_samples_accumulated = 0
        self.local_times_accumulated = 0
        self._anchor_batch_size = None
        self._local_accumulators = None
        if not reuse_grad_buffers:
            self._local_accumulators = tuple(
                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
            )
        self._accumulators_used_in_step = False
        self._new_averaged_grads = False

        with torch.no_grad():
            if not averaged_grads:
                averaged_grads = tuple(
                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
                )
        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)

    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
        """gradient buffers associated with parameters"""
        for param in self.parameters:
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            yield param.grad

    @torch.no_grad()
    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
        """averager-based gradient accumulators"""
        assert (self._local_accumulators is None) == self.reuse_grad_buffers
        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators

    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int):
        """add current gradients to local grad accumulators (if used)"""
        if self._accumulators_used_in_step and self.warn:
            logger.warning(
                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)"
            )
            self._accumulators_used_in_step = False  # warn once per round
        if self._anchor_batch_size is None:
            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
            self._anchor_batch_size = batch_size
        self.local_samples_accumulated += batch_size
        self.local_times_accumulated += 1
        if self.reuse_grad_buffers:
            pass  # user is responsible for accumulating gradients in .grad buffers
        else:
            alpha = float(batch_size) / self._anchor_batch_size
            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
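
    # Usage note (a sketch, not enforced by this class): with reuse_grad_buffers=False the
    # accumulators are separate from param.grad, so zero the model's gradients after each
    # call to avoid adding the same micro-batch twice, e.g.:
    #
    #   loss.backward()
    #   grad_averager.accumulate_grads_(batch_size=batch_size)
    #   opt.zero_grad()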

    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
        """
        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.

        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
        :note: setting weight at this stage is not supported, please leave this parameter as None
        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
        :note: in the current implementation, each step_control can only be used in one step.
        """
        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
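
    # Example (a sketch with illustrative names): pre-schedule matchmaking ahead of time and
    # trigger the all-reduce later; each StepControl may be passed to .step(control=...) exactly once:
    #
    #   control = grad_averager.schedule_step(scheduled_time=get_dht_time() + 15)
    #   ... keep accumulating gradients ...
    #   grad_averager.step(control=control)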

    def step(
        self,
        weight: Optional[float] = None,
        reset_accumulators: bool = True,
        control: Optional[StepControl] = None,
        timeout: Optional[float] = None,
        wait: bool = True,
        **kwargs,
    ):
        """
        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators.

        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
        :param timeout: if specified, await the averaging round for at most this number of seconds (if wait=True)
        :param wait: if True, await the step to finish (or fail), otherwise run all-reduce in the background
        """
        if control is None:
            control = self.schedule_step(timeout=timeout, **kwargs)
        elif len(kwargs) > 0:
            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect")
        assert not control.triggered, f"This {type(control)} instance was already used"
        if self._new_averaged_grads and self.warn:
            logger.warning(
                "[warn=True] Starting new averaging round, but previous round results were not used. "
                "This may be a sign of incorrect optimizer behavior"
            )
        self.load_accumulators_into_averager_()
        self._accumulators_used_in_step = True
        self._new_averaged_grads = True

        control.weight = self.local_samples_accumulated if weight is None else weight
        if reset_accumulators:
            self.reset_accumulated_grads_()
        control.allow_allreduce()
        return control.result(timeout) if wait else control
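
    # Example (a sketch): run the averaging round in the background and collect the result later:
    #
    #   control = grad_averager.step(wait=False)
    #   ... do other work while all-reduce runs ...
    #   gathered = control.result()  # blocks until the round finishes, or raises on failure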

    @torch.no_grad()
    def load_accumulators_into_averager_(self):
        """load locally accumulated gradients into the averager for aggregation"""
        # divide locally accumulated gradients by the number of times they were accumulated
        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
        with self.get_tensors() as averaged_grads:
            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)

    @torch.no_grad()
    def reset_accumulated_grads_(self):
        """reset averager-internal gradient accumulators and the denominator"""
        self._accumulators_used_in_step = False
        self.local_samples_accumulated = self.local_times_accumulated = 0
        self._anchor_batch_size = None
        for grad_buf in self._grad_accumulators():
            grad_buf.zero_()

    @contextlib.contextmanager
    @torch.no_grad()
    def use_averaged_gradients(self):
        """Substitute the model's main gradients with averaged gradients (does not respect device placement)"""
        self._new_averaged_grads = False
        with self.get_tensors() as averaged_grads:
            assert len(averaged_grads) == len(self.parameters)
            try:
                old_grads = [param.grad for param in self.parameters]
                for param, new_grad in zip(self.parameters, averaged_grads):
                    param.grad = new_grad
                yield averaged_grads
            finally:
                for param, old_grad in zip(self.parameters, old_grads):
                    param.grad = old_grad

    def notify_used_averaged_gradients(self):
        """Notify averager that the results of a previous averaging round are accounted for"""
        self._new_averaged_grads = False
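

# ---------------------------------------------------------------------------
# Minimal standalone usage sketch (an illustration, not part of the class above).
# It assumes at least two peers run this script; every peer except the first
# should fill INITIAL_PEERS with the first peer's visible multiaddrs. The model,
# data, and the "demo_grads" prefix below are placeholders.
if __name__ == "__main__":
    INITIAL_PEERS = []  # e.g. copy the multiaddrs printed by the first peer

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS or None, start=True)
    print("To join this run from another peer, use initial_peers =", dht.get_visible_maddrs())

    model = torch.nn.Linear(16, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    grad_averager = GradientAverager(model.parameters(), dht=dht, prefix="demo_grads", start=True)

    for global_step in range(10):
        # accumulate a few toy micro-batches locally
        for _ in range(4):
            xs, ys = torch.randn(32, 16), torch.randn(32, 1)
            loss = torch.nn.functional.mse_loss(model(xs), ys)
            loss.backward()
            grad_averager.accumulate_grads_(batch_size=32)
            opt.zero_grad()  # .grad was already copied into the internal accumulators
        grad_averager.step()  # all-reduce accumulated gradients with the peers that are online
        with grad_averager.use_averaged_gradients():
            opt.step()  # update parameters using the averaged gradients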