power_ef_averager.py

import asyncio
import contextlib
import faulthandler
import math
import torch
import multiprocessing as mp
from typing import Any, Iterable, Optional, Sequence

import hivemind
from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
from hivemind.averaging.control import AveragingStage, StepControl
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.compression import (
    CompressionBase,
    CompressionInfo,
    NoCompression,
    deserialize_torch_tensor,
    serialize_torch_tensor,
)
from hivemind.dht import DHT, DHTID
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import (
    achain,
    aiter_with_timeout,
    anext,
    as_aiter,
    azip,
    enter_asynchronously,
    switch_to_uvloop,
)
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time

from .grad_averager import GradientAverager

GatheredData = Any
logger = get_logger(__name__)
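
# PowerEFGradientAverager appears to follow the PowerSGD family of low-rank gradient compression
# with an error-feedback-style running estimate: every matrix-shaped gradient is communicated as
# rank-`averager_rank` factors P and Q, while `self._gs` keeps each peer's reconstruction of the
# averaged gradient and only the residual (local gradient minus reconstruction) is compressed
# each round. One-dimensional gradients (biases, norms) are averaged without compression.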


class PowerEFGradientAverager(GradientAverager):
    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        averager_rank: int,
        *,
        dht: hivemind.DHT,
        prefix: str,
        reuse_grad_buffers: bool = False,
        accumulate_grads_on: Optional[torch.device] = None,
        client_mode: Optional[bool] = None,
        warn: bool = True,
        **kwargs,
    ):
        self.rank = averager_rank
        self.parameters = tuple(parameters)
        # 1D gradients (biases, layer norms) are kept uncompressed: they are cheap to average as-is
        self._uncompressed_gradients = set(
            i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1
        )
        # running reconstruction of the averaged gradient, one buffer per compressed gradient
        self._gs = list(
            torch.zeros_like(grad, device=accumulate_grads_on)
            for idx, grad in enumerate(self._grads_from_parameters())
            if idx not in self._uncompressed_gradients
        )
        # right low-rank factors Q, shaped (flattened columns, rank)
        self._qs = list(
            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device=accumulate_grads_on)
            for idx, grad in enumerate(self._grads_from_parameters())
            if idx not in self._uncompressed_gradients
        )
        for tensor in self._qs + self._gs:
            if tensor is not None:
                assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                tensor.share_memory_()

        super().__init__(
            self.parameters,
            dht=dht,
            prefix=prefix,
            reuse_grad_buffers=reuse_grad_buffers,
            accumulate_grads_on=accumulate_grads_on,
            client_mode=client_mode,
            warn=warn,
            **kwargs,
        )

    @contextlib.contextmanager
    def _register_allreduce_group(self, group_info: GroupInfo):
        """registers a given all-reduce runner to listen for incoming connections"""
        try:
            self._running_groups[group_info.group_id + b'.phase1'] = asyncio.Future()
            self._running_groups[group_info.group_id + b'.phase2'] = asyncio.Future()
            self._pending_groups_registered.set()
            yield
        finally:
            maybe_future = self._running_groups.pop(group_info.group_id + b'.phase1', None)
            if maybe_future and not maybe_future.done():
                logger.warning(f"All-reduce group {group_info.group_id + b'.phase1'} did not finish.")
            maybe_future = self._running_groups.pop(group_info.group_id + b'.phase2', None)
            if maybe_future and not maybe_future.done():
                logger.warning(f"All-reduce group {group_info.group_id + b'.phase2'} did not finish.")
            self._pending_groups_registered.set()
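
    # Each averaging round below consists of two chained all-reduce phases:
    #   phase 1 averages the left factors P together with the uncompressed 1D gradients,
    #   phase 2 averages the right factors Q after P has been re-orthogonalized;
    # the averaged residual is then reconstructed locally as P @ Q^T and folded into `self._gs`.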

    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""

        async def _dump_later():
            # debug helper (left disabled): dump stacks of all running tasks if averaging stalls
            await asyncio.sleep(15.0)
            print([*map(asyncio.Task.print_stack, asyncio.all_tasks())])

        # task = asyncio.create_task(_dump_later())
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))

            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                compressed_tensors = [
                    lt for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients
                ]
                # residual to be compressed this round: local gradient minus current reconstruction
                cs = [torch.zeros_like(grad, device="cpu") for grad in compressed_tensors]
                for c, g, cg in zip(cs, self._gs, compressed_tensors):
                    torch.sub(cg, g, out=c)

                # left factors P = C @ Q, one per compressed gradient
                ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
                for p, q, c in zip(ps, self._qs, cs):
                    torch.matmul(c.reshape(-1, q.size(0)), q, out=p)

                # phase 1: average P factors together with the uncompressed (1D) gradients
                first_all_reduced = ps + [
                    lt for idx, lt in enumerate(local_tensors) if idx in self._uncompressed_gradients
                ]
                allreduce1 = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id + b'.phase1',
                    tensors=first_all_reduced,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )
                self._running_groups[group_info.group_id + b'.phase1'].set_result(allreduce1)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                else:
                    async for _ in allreduce1:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

                # re-orthogonalize the averaged P factors (Gram-Schmidt)
                for p in ps:
                    orthogonalize(p)

                # compute the right factors Q = C^T @ P from the local residuals
                for p, q, c in zip(ps, self._qs, cs):
                    torch.matmul(c.reshape(-1, q.size(0)).t(), p, out=q)

                # phase 2: average the Q factors
                allreduce2 = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id + b'.phase2',
                    tensors=self._qs,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )
                self._running_groups[group_info.group_id + b'.phase2'].set_result(allreduce2)

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                    self.last_updated = get_dht_time()
                    self._state_updated.set()
                else:
                    async for _ in allreduce2:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

                # reconstruct the averaged residual C ≈ P @ Q^T and fold it into the running estimate
                for p, q, c in zip(ps, self._qs, cs):
                    new_c = torch.matmul(p, q.t())
                    c.copy_(new_c.reshape(c.size()))
                for c, g in zip(cs, self._gs):
                    torch.add(g, c * 0.9, out=g)  # the averaged residual is damped by a fixed 0.9 factor

                return allreduce1.gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
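
    # `use_averaged_gradients` temporarily copies the running reconstruction from `self._gs` into the
    # compressed gradient buffers, points each parameter's `.grad` at the averaged buffers for the
    # duration of the `with` block, and restores both on exit, so an optimizer stepping inside the
    # block consumes the reconstructed (averaged) gradients rather than the raw local ones.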

    @contextlib.contextmanager
    @torch.no_grad()
    def use_averaged_gradients(self):
        self._new_averaged_grads = False
        with self.get_tensors() as averaged_grads:
            compressed_tensors = [
                lt for idx, lt in enumerate(averaged_grads) if idx not in self._uncompressed_gradients
            ]
            old_averaged = [torch.zeros_like(lt) for lt in compressed_tensors]
            for g, cg, oag in zip(self._gs, compressed_tensors, old_averaged):
                oag.copy_(cg)
                cg.copy_(g)
            try:
                assert len(averaged_grads) == len(self.parameters)
                old_grads = [param.grad for param in self.parameters]
                for param, new_grad in zip(self.parameters, averaged_grads):
                    param.grad = new_grad
                yield
            finally:
                for param, old_grad in zip(self.parameters, old_grads):
                    param.grad = old_grad
                for cg, oag in zip(compressed_tensors, old_averaged):
                    cg.copy_(oag)


@torch.jit.script
def orthogonalize(matrix, eps=torch.tensor(1e-8)):
    """Orthogonalize the columns of `matrix` in place via numerically stabilized Gram-Schmidt."""
    n, m = matrix.shape
    for i in range(m):
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col ** 2)) + eps
        if i + 1 < m:
            rest = matrix[:, i + 1 :]
            rest -= torch.sum(col * rest, dim=0) * col
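

# A minimal, illustrative usage sketch following hivemind's standard GradientAverager workflow.
# It assumes `start=True` is forwarded to the underlying DecentralizedAverager and that, in a real
# run, each peer passes `initial_peers=...` so the DHTs connect and at least two peers share the
# same `prefix` to form an averaging group; the model and prefix below are placeholders.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)
    dht = hivemind.DHT(start=True)
    averager = PowerEFGradientAverager(
        model.parameters(),
        averager_rank=1,
        dht=dht,
        prefix="demo_power_ef_averager",
        start=True,
    )

    # accumulate one local gradient, average it with the group, then step on the averaged result
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    averager.accumulate_grads_(batch_size=8)
    averager.step(wait=True, timeout=60)
    with averager.use_averaged_gradients():
        torch.optim.SGD(model.parameters(), lr=0.1).step()

    averager.shutdown()
    dht.shutdown()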