# power_ef_averager.py

import asyncio
import contextlib
import faulthandler
import math
import multiprocessing as mp
from typing import Any, Iterable, Optional, Sequence

import numpy as np
import torch

import hivemind
from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
from hivemind.averaging.control import AveragingStage, StepControl
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.compression import (
    CompressionBase,
    CompressionInfo,
    NoCompression,
    deserialize_torch_tensor,
    serialize_torch_tensor,
)
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import (
    achain,
    aiter_with_timeout,
    anext,
    as_aiter,
    azip,
    enter_asynchronously,
    switch_to_uvloop,
)
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time

from .grad_averager import GradientAverager

GatheredData = Any
logger = get_logger(__name__)


class PowerEFGradientAverager(GradientAverager):
    """
    A gradient averager that exchanges gradients compressed with a rank-`averager_rank`
    PowerSGD-style low-rank factorization plus an error-feedback residual; gradients for
    which low-rank compression would not pay off are all-reduced uncompressed.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        averager_rank: int,
        *,
        dht: hivemind.DHT,
        prefix: str,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: Optional[bool] = None,
        warn: bool = True,
        min_compression_ratio: float = 0.5,
        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
        **kwargs,
    ):
        self.rank = averager_rank
        self.parameters = tuple(parameters)
        # 1D gradients, and gradients whose rank-r factors would not be at least
        # min_compression_ratio smaller than the full tensor, are sent uncompressed
        self._uncompressed_gradients = set(
            i
            for i, grad in enumerate(self._grads_from_parameters())
            if len(tuple(grad.size())) == 1
            or (
                self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size())
                > 1 - min_compression_ratio
            )
        )
        self._gradient_residual = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
        self._qs = list(
            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
            for idx, grad in enumerate(self._grads_from_parameters())
            if idx not in self._uncompressed_gradients
        )
        for tensor in self._qs + self._gradient_residual:
            if tensor is not None:
                assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                tensor.share_memory_()

        self.all_reduce_phases = (b".phase1", b".phase2")

        super().__init__(
            self.parameters,
            dht=dht,
            prefix=prefix,
            reuse_grad_buffers=reuse_grad_buffers,
            accumulate_grads_on=accumulate_grads_on,
            client_mode=client_mode,
            warn=warn,
            averaged_grads=averaged_grads,
            **kwargs,
        )
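
    # Illustrative note on the compression criterion above (the numbers are an example, not
    # from the source): for a (1024, 4096) gradient with rank 4, the factors P (1024 x 4) and
    # Q (4096 x 4) carry 4 * (1024 + 4096) = 20480 values instead of 1024 * 4096 = 4194304,
    # a ratio of ~0.005 <= 1 - min_compression_ratio, so that gradient is compressed;
    # 1D tensors such as biases are always sent uncompressed.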

    @contextlib.contextmanager
    def _register_allreduce_group(self, group_info: GroupInfo):
        """registers a given all-reduce runner to listen for incoming connections"""
        try:
            for phase in self.all_reduce_phases:
                self._running_groups[group_info.group_id + phase] = asyncio.Future()
            self._pending_groups_registered.set()
            yield
        finally:
            for phase in self.all_reduce_phases:
                maybe_future = self._running_groups.pop(group_info.group_id + phase, None)
                if maybe_future and not maybe_future.done():
                    logger.warning(f"All-reduce group {group_info.group_id + phase} did not finish.")
            self._pending_groups_registered.set()

    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run two-phase All-Reduce in a given group and update tensors in place, return gathered metadata"""
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))

            # client-mode peers do not accept incoming parts, so they get zero download bandwidth
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
                cs = [
                    rest for idx, rest in enumerate(self._gradient_residual) if idx not in self._uncompressed_gradients
                ]
                ps = [
                    torch.zeros((grad.size(0), self.rank), device="cpu")
                    for idx, grad in enumerate(averaged_grads)
                    if idx not in self._uncompressed_gradients
                ]
                for p, q, rest in zip(ps, self._qs, cs):
                    torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)

                # phase 1: average the P factors together with all uncompressed gradients
                first_all_reduced = ps + [
                    rest for idx, rest in enumerate(self._gradient_residual) if idx in self._uncompressed_gradients
                ]
                allreduce1 = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id + self.all_reduce_phases[0],
                    tensors=first_all_reduced,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )
                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce1)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                else:
                    async for _ in allreduce1:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

                # orthogonalize the averaged P factors, then recompute the Q factors from them
                for p in ps:
                    orthogonalize(p)
                for p, q, c in zip(ps, self._qs, cs):
                    torch.matmul(c.reshape(-1, q.size(0)).t(), p, out=q)

                # phase 2: average the Q factors
                allreduce2 = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id + self.all_reduce_phases[1],
                    tensors=self._qs,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )
                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce2)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                    self.last_updated = get_dht_time()
                    self._state_updated.set()
                else:
                    async for _ in allreduce2:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

                # reconstruct the averaged low-rank gradients as P Q^T and add them onto the averaged buffers
                for p, q, c in zip(ps, self._qs, cs):
                    new_c = torch.matmul(p, q.t())
                    c.copy_(new_c.reshape(c.size()))
                for rest, grad in zip(self._gradient_residual, averaged_grads):
                    torch.add(grad, rest, out=grad)

            return allreduce1.gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}") from e

    @torch.no_grad()
    def load_accumulators_into_averager_(self):
        """load locally accumulated gradients into the averager for aggregation"""
        # divide locally accumulated gradients by the number of times they were accumulated
        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
        with self.get_tensors() as averaged_grads:
            for grad_acc, averaged_grad, rest in zip(
                self._grad_accumulators(), averaged_grads, self._gradient_residual
            ):
                rest.copy_(grad_acc, non_blocking=False).mul_(grad_scale).sub_(averaged_grad)
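
    # Note: after this call each residual holds grad_acc / local_times_accumulated - averaged_grad,
    # i.e. the freshly accumulated local gradient minus the previously averaged one; _run_allreduce
    # then averages a low-rank approximation of this difference and adds it back onto averaged_grads.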


@torch.jit.script
def orthogonalize(matrix, eps=torch.tensor(1e-8)):
    """Orthogonalize the columns of a 2D matrix in place via Gram-Schmidt"""
    n, m = matrix.shape
    for i in range(m):
        # normalize the i-th column to unit norm (eps guards against division by zero)
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col ** 2)) + eps
        # subtract its projection from all subsequent columns
        if i + 1 < m:
            rest = matrix[:, i + 1 :]
            rest -= torch.sum(col * rest, dim=0) * col
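

if __name__ == "__main__":
    # A minimal, self-contained sketch (an illustration, not part of the averager itself) of the
    # two-phase compression math performed in _run_allreduce, run locally with no DHT or peers:
    # P = C Q is the phase-1 payload, P is orthogonalized, Q = C^T P is the phase-2 payload,
    # and C_hat = P Q^T reconstructs the gradient. All tensor names below are illustrative
    # stand-ins for single entries of self._gradient_residual and self._qs.
    torch.manual_seed(0)
    rank = 4
    residual = torch.randn(256, 512)  # stand-in for one gradient residual, reshaped to 2D
    q = torch.rand(residual.size(1), rank)  # stand-in for the matching Q factor

    p = residual @ q  # phase 1 payload: P = C Q (averaged across peers in training)
    orthogonalize(p)  # orthonormalize the columns of P, as done between the two phases
    q = residual.t() @ p  # phase 2 payload: Q = C^T P (also averaged across peers)
    approx = p @ q.t()  # reconstruction C_hat = P Q^T

    sent = p.numel() + q.numel()
    rel_err = (torch.linalg.norm(residual - approx) / torch.linalg.norm(residual)).item()
    print(f"sent {sent} values instead of {residual.numel()} (ratio {sent / residual.numel():.4f})")
    print(f"relative reconstruction error: {rel_err:.4f}")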