grad_averager.py

import contextlib
from typing import Iterable, Iterator, Optional

import torch

import hivemind
from hivemind.averaging import DecentralizedAverager
from hivemind.averaging.control import StepControl
from hivemind.utils import DHTExpiration, get_dht_time, get_logger

logger = get_logger(__name__)


class GradientAverager(DecentralizedAverager):
    """
    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
    GradientAverager is meant to be used within hivemind.Optimizer, but it can also be used standalone (see example below).

    GradientAverager manages three sets of buffers:

    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
        These tensors are typically stored on device and updated by torch autograd.
    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
        - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
          which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory.

    :param parameters: pytorch parameters for which to aggregate gradients
    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
    :param reuse_grad_buffers: if True, use the model's .grad buffers for accumulating gradients over multiple steps.
      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
    :param client_mode: if False, this averager will accept incoming requests from other peers.
      If True, the averager will only join existing groups where at least one peer has client_mode=False.
      By default, this flag is copied from DHTNode inside the ``dht`` instance.
    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters

    Example:

    >>> model = SuchModelMuchLayers()
    >>> opt = torch.optim.Adam(model.parameters())
    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
    >>> next_step_time = hivemind.get_dht_time() + 60  # runs global steps every 60 seconds
    >>> next_step_control = None
    >>> while True:
    >>>     # accumulate as many gradients as you can before next_step_time
    >>>     loss = compute_loss(model, batch_size=32)
    >>>     loss.backward()
    >>>     grad_averager.accumulate_grads_(batch_size=32)
    >>>     # [optional] next step in 5 seconds, start looking for peers in advance
    >>>     if next_step_time - hivemind.get_dht_time() <= 5:
    >>>         next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
    >>>     # aggregate gradients and perform optimizer step
    >>>     if hivemind.get_dht_time() >= next_step_time:
    >>>         grad_averager.step(control=next_step_control)
    >>>         with grad_averager.use_averaged_gradients():  # this will fill param.grads with aggregated gradients
    >>>             opt.step()  # update model parameters using averaged gradients
    >>>         grad_averager.reset_accumulated_grads_()  # prepare for next step
    >>>         next_step_time = hivemind.get_dht_time() + 60
    >>>         next_step_control = None
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        *,
        dht: hivemind.DHT,
        prefix: str,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: Optional[bool] = None,
        warn: bool = True,
        **kwargs,
    ):
        if reuse_grad_buffers and accumulate_grads_on is not None:
            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
        client_mode = client_mode if client_mode is not None else dht.client_mode
        self.parameters = tuple(parameters)
        self.reuse_grad_buffers = reuse_grad_buffers
        self.warn = warn
        self.local_samples_accumulated = 0
        self.local_times_accumulated = 0
        self._anchor_batch_size = None
        self._local_accumulators = None
        if not reuse_grad_buffers:
            self._local_accumulators = tuple(
                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
            )
        self._accumulators_used_in_step = False
        self._new_averaged_grads = False
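        # averaged gradients live in shared host memory so that the background averaging process
        # (inherited from DecentralizedAverager) can read and update them in-place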
        with torch.no_grad():
            averaged_grads = tuple(
                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
            )
        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)

    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
        """gradient buffers associated with parameters"""
        for param in self.parameters:
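            # lazily create a zero gradient buffer for parameters that have no .grad yet,
            # so that every parameter has a matching buffer for accumulation and averaging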
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            yield param.grad

    @torch.no_grad()
    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
        """averager-based gradient accumulators"""
        assert (self._local_accumulators is None) == self.reuse_grad_buffers
        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators

    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int):
        """add current gradients to local grad accumulators (if used)"""
        if self._accumulators_used_in_step and self.warn:
            logger.warning(
                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)"
            )
            self._accumulators_used_in_step = False  # warn once per round
        if self._anchor_batch_size is None:
            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
            self._anchor_batch_size = batch_size
        self.local_samples_accumulated += batch_size
        self.local_times_accumulated += 1
        if self.reuse_grad_buffers:
            pass  # user is responsible for accumulating gradients in .grad buffers
        else:
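            # e.g. if the anchor micro-batch had 32 samples and this one has 16, alpha = 16 / 32 = 0.5,
            # so every individual sample contributes equally to the accumulator regardless of batch size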
            alpha = float(batch_size) / self._anchor_batch_size
            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)

    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
        """
        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.

        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
        :note: setting weight at this stage is not supported, please leave this parameter as None
        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
        :note: in the current implementation, each step_control can only be used in one step.
        """
        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)

    def step(
        self,
        weight: Optional[float] = None,
        reset_accumulators: bool = True,
        control: Optional[StepControl] = None,
        timeout: Optional[float] = None,
        wait: bool = True,
        **kwargs,
    ):
        """
        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators.

        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
        :param timeout: if specified, wait for the averaging round for at most this number of seconds (if wait=True)
        :param wait: if True, wait for the step to finish (or fail); otherwise, run all-reduce in the background
        """
        if control is None:
            control = self.schedule_step(timeout=timeout, **kwargs)
        elif len(kwargs) > 0:
            raise RuntimeError(f"Averaging with a pre-scheduled group; parameters {kwargs} will have no effect")
        assert not control.triggered, f"This {type(control)} instance was already used"
        if self._new_averaged_grads and self.warn:
            logger.warning(
                "[warn=True] Starting a new averaging round, but the previous round's results were not used. "
                "This may be a sign of incorrect optimizer behavior"
            )
        self.load_accumulators_into_averager_()
        self._accumulators_used_in_step = True
        self._new_averaged_grads = True
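        # by default, each peer's contribution is weighted by the number of samples it accumulated,
        # so a peer that processed twice as many samples has twice the influence on the average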
        control.weight = self.local_samples_accumulated if weight is None else weight
        if reset_accumulators:
            self.reset_accumulated_grads_()
        control.allow_allreduce()
        return control.result(timeout) if wait else control

    @torch.no_grad()
    def load_accumulators_into_averager_(self):
        """load locally accumulated gradients into the averager for aggregation"""
        # divide locally accumulated gradients by the number of times they were accumulated
        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
        with self.get_tensors() as averaged_grads:
            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)

    @torch.no_grad()
    def reset_accumulated_grads_(self):
        """reset averager-internal gradient accumulators and the denominator"""
        self._accumulators_used_in_step = False
        self.local_samples_accumulated = self.local_times_accumulated = 0
        self._anchor_batch_size = None
        for grad_buf in self._grad_accumulators():
            grad_buf.zero_()

    @contextlib.contextmanager
    @torch.no_grad()
    def use_averaged_gradients(self):
        """Substitute the model's main gradients with averaged gradients (does not respect device placement)"""
        self._new_averaged_grads = False
        with self.get_tensors() as averaged_grads:
            assert len(averaged_grads) == len(self.parameters)
            try:
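                # remember the original gradient buffers so they can be restored when the context exits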
                old_grads = [param.grad for param in self.parameters]
                for param, new_grad in zip(self.parameters, averaged_grads):
                    param.grad = new_grad
                yield averaged_grads
            finally:
                for param, old_grad in zip(self.parameters, old_grads):
                    param.grad = old_grad

    def notify_used_averaged_gradients(self):
        """Notify averager that the results of a previous averaging round are accounted for"""
        self._new_averaged_grads = False
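

if __name__ == "__main__":
    # A minimal usage sketch of the intended call order, assuming a toy linear model and a fresh
    # single-node DHT; the prefix "grad_averager_demo" is a placeholder. In practice, two or more
    # peers should run this with shared initial_peers, otherwise step() has nobody to average with
    # and will give up after the timeout.
    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    grad_averager = GradientAverager(model.parameters(), dht=dht, prefix="grad_averager_demo", start=True)

    for _ in range(4):  # accumulate gradients over several micro-batches
        loss = model(torch.randn(32, 16)).pow(2).mean()
        loss.backward()
        grad_averager.accumulate_grads_(batch_size=32)
        opt.zero_grad()  # safe because reuse_grad_buffers=False (the default): accumulators are separate

    grad_averager.step(timeout=60)  # aggregate accumulated gradients with whichever peers joined the group
    with grad_averager.use_averaged_gradients():  # temporarily place averaged gradients into param.grad
        opt.step()
    grad_averager.reset_accumulated_grads_()  # reset accumulators before the next round

    grad_averager.shutdown()
    dht.shutdown()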