power_sgd_averager.py
import asyncio
import contextlib
import multiprocessing as mp
from typing import Any, Iterable, Optional, Sequence

import torch

from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.compression import CompressionInfo, TensorRole
from hivemind.dht import DHT
from hivemind.optim.grad_averager import GradientAverager
from hivemind.utils import get_logger
from hivemind.utils.asyncio import as_aiter, azip, enter_asynchronously
from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
from hivemind.utils.timed_storage import get_dht_time

GatheredData = Any
logger = get_logger(__name__)


class PowerSGDGradientAverager(GradientAverager):
    """
    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
    For basic properties and guarantees of gradient averagers, please refer to the base class docstring.
    Put simply, this method approximates large gradient tensors (m,n) with a product of two
    smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).

    As a result, PowerSGD only needs to aggregate O((m + n) * r) values instead of O(m * n).
    A high r, e.g. sqrt(max(m, n)), typically reduces communication by 2-8x without affecting convergence.
    A low r, e.g. 1-8, accelerates communication further, but may converge worse depending on the task.
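    For example, a (1024, 4096) gradient tensor compressed with r = 32 is transmitted as
    (1024 + 4096) * 32 = 163,840 values instead of 1024 * 4096 = 4,194,304, a ~25x reduction.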

    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
    if some part of the gradient is "lost in compression", it will be added to the next iteration.
    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
    and (b) if devices stay alive only for one step, training with a small rank may converge slower.
    This is because error feedback takes multiple steps to kick in.

    Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
    If a tensor has less than 2 dimensions or does not compress efficiently, it will be aggregated
    normally, i.e. without PowerSGD. See min_compression_ratio for details.

    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
      matrix of shape (256, 256) will be compressed differently if you .reshape it to (32, 32, 32).

    :param parameters: pytorch parameters for which to aggregate gradients
    :param averager_rank: rank of the low-rank approximation used to compress gradient tensors
    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication
      by at least this factor; otherwise, aggregate that tensor as is
    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
    :param reuse_grad_buffers: if True, use the model's .grad buffers for accumulating gradients over multiple steps.
      This is more memory-efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
    :param client_mode: if False, this averager will accept incoming requests from other peers.
      If True, the averager will only join existing groups where at least one peer has client_mode=False.
      By default, this flag is copied from DHTNode inside the ``dht`` instance.
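
    Example (a minimal usage sketch, not part of the original file; assumes a torch model named
    ``model`` and mirrors the base GradientAverager example):

    >>> dht = hivemind.DHT(start=True)
    >>> averager = PowerSGDGradientAverager(model.parameters(), averager_rank=32,
    ...                                     dht=dht, prefix="my_experiment_grads", start=True)
    >>> averager.accumulate_grads_(batch_size=32)  # call this after each backward pass
    >>> averager.step()  # periodically average the accumulated gradients with peers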
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        averager_rank: int,
        *,
        dht: DHT,
        prefix: str,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: Optional[bool] = None,
        warn: bool = True,
        min_compression_ratio: float = 0.5,
        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
        **kwargs,
    ):
        self.rank = averager_rank
        self.parameters = tuple(parameters)
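        # a gradient is aggregated without compression if it has at most one dimension or if its
        # rank-r factorization would not save at least min_compression_ratio of the communication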
        self._uncompressed_gradients_indexes = set(
            i
            for i, grad in enumerate(self._grads_from_parameters())
            if len(tuple(grad.size())) <= 1
            or (
                1
                - self.rank
                * sum(get_flatten_greedy_dims(grad))
                / (get_flatten_greedy_dims(grad)[0] * get_flatten_greedy_dims(grad)[1])
                < min_compression_ratio
            )  # the left-hand side is the fraction of communication saved by factorization
        )
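        # error feedback buffers: they accumulate the part of each gradient that was
        # lost in compression so that it can be added back in subsequent rounds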
        self._ms = [
            torch.zeros_like(grad, device="cpu").share_memory_()
            for idx, grad in enumerate(self._grads_from_parameters())
            if idx not in self._uncompressed_gradients_indexes
        ]
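        # q factors are initialized randomly and kept between rounds, so every new round
        # warm-starts its power iteration from the previous approximation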
        self._qs = [
            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
            for idx, grad in enumerate(self._grads_from_parameters())
            if idx not in self._uncompressed_gradients_indexes
        ]

        self.all_reduce_phases = (b".phase_p", b".phase_q")

        super().__init__(
            self.parameters,
            dht=dht,
            prefix=prefix,
            reuse_grad_buffers=reuse_grad_buffers,
            accumulate_grads_on=accumulate_grads_on,
            client_mode=client_mode,
            warn=warn,
            averaged_grads=None,
            **kwargs,
        )

    @contextlib.contextmanager
    def _register_allreduce_group(self, group_info: GroupInfo):
        """registers a given group for both all-reduce phases to listen for incoming connections"""
        try:
            for phase in self.all_reduce_phases:
                self._running_groups[group_info.group_id + phase] = asyncio.Future()
            self._pending_groups_registered.set()
            yield
        finally:
            for phase in self.all_reduce_phases:
                maybe_future = self._running_groups.pop(group_info.group_id + phase, None)
                if maybe_future and not maybe_future.done():
                    logger.warning(f"All-reduce group {group_info.group_id + phase} did not finish.")
            self._pending_groups_registered.set()

    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
                # split off the gradients that are compressed with PowerSGD so that they
                # map one-to-one onto the m and q buffers
                averaged_grads_via_sgd = [
                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
                ]
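                # error feedback: accumulate the current gradient into the residual buffer;
                # m now holds the gradient plus the error left over from previous rounds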
                for grad, m in zip(averaged_grads_via_sgd, self._ms):
                    m.add_(grad.to(m.device))

                ps = [
                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
                    for idx, grad in enumerate(averaged_grads_via_sgd)
                ]
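                # step 1 of a single rank-r power iteration: p = m @ q, reusing the qs from the previous round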
                for p, q, m in zip(ps, self._qs, self._ms):
                    # we use reshape because PowerSGD works only with 2d tensors
                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)

                allreduce_p_phase = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id + self.all_reduce_phases[0],
                    tensors=ps,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )
                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce_p_phase)
                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    async for tensor, update in azip(as_aiter(*ps), allreduce_p_phase):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                else:
                    async for _ in allreduce_p_phase:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")
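                # orthogonalize the averaged ps (Gram-Schmidt) so that the subsequent q step
                # projects onto an orthonormal basis, as in the PowerSGD algorithm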
                for p in ps:
                    orthogonalize_(p)
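                # step 2 of the power iteration: project the residuals onto the averaged basis, q = m^T @ p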
                for p, q, m in zip(ps, self._qs, self._ms):
                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
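                # the q factors are all-reduced together with the gradients that skipped compression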
                averaged_grads_wo_sgd = [
                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
                ]
                allreduce_q_phase = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id + self.all_reduce_phases[1],
                    tensors=self._qs + averaged_grads_wo_sgd,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )
                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce_q_phase)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    async for tensor, update in azip(as_aiter(*self._qs, *averaged_grads_wo_sgd), allreduce_q_phase):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                    self.last_updated = get_dht_time()
                    self._state_updated.set()
                else:
                    async for _ in allreduce_q_phase:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")
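                # update the error feedback buffers: m keeps only the compression error (m - p @ q^T),
                # while the averaged low-rank approximation is written back into the gradient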
                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
                    new_m = torch.matmul(p, q.t()).reshape(m.size())
                    m.sub_(new_m)
                    grad.copy_(new_m)

                return allreduce_q_phase.gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")

    def get_current_state(self):
        """Get the current q buffers and metadata for sharing with other peers"""
        with torch.no_grad(), self.lock_averaged_tensors:
            grad_averager_buffers = list(self._qs)
            grad_averager_buffers_infos = [
                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
                for key, buffer in enumerate(grad_averager_buffers)
            ]
            metadata = dict(group_bits=self.get_group_bits())
            return metadata, grad_averager_buffers, grad_averager_buffers_infos

    def load_state_from_peers(self, **kwargs):
        loaded_state = super().load_state_from_peers(**kwargs)
        if loaded_state is None:
            return

        metadata, flat_tensors = loaded_state
        logger.info("Loading gradient averager buffers from peers")

        if len(flat_tensors) != len(self._qs):
            logger.error("Failed to load state from peer: the number of received tensors does not match local q buffers")
            return

        with torch.no_grad(), self.lock_averaged_tensors:
            for local_q, loaded_q in zip(self._qs, flat_tensors):
                local_q.copy_(loaded_q, non_blocking=True)