# allreduce.py
  1. import asyncio
  2. from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
  3. from enum import Enum
  4. import torch
  5. from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
  6. from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
  7. from hivemind.utils import get_logger
  8. from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
  9. from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
  10. from hivemind.proto import averaging_pb2
# flavour types
GroupID = bytes  # opaque identifier shared by all peers participating in one all-reduce run

logger = get_logger(__name__)
  14. class AveragingMode(Enum):
  15. NODE = 0
  16. CLIENT = 1
  17. AUX = 2
class AllReduceRunner(ServicerBase):
    """
    An internal class that runs butterfly AllReduce in a predefined group of averagers.

    This class inherits hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes without
    creating a full DecentralizedAverager.

    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
    :param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
      Typically, it is a DecentralizedAverager instance or its derivative.
      If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
    :param group_id: unique identifier of this specific all-reduce run
    :param tensors: local tensors that should be averaged with groupmates
    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
    :param gathered: additional user-defined data collected from this group
    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
    """

    def __init__(
        self,
        *,
        p2p: P2P,
        servicer: Optional[ServicerBase],
        group_id: GroupID,
        tensors: Sequence[torch.Tensor],
        ordered_group_endpoints: Sequence[Endpoint],
        peer_fractions: Tuple[float, ...],
        weights: Optional[Sequence[float]] = None,
        modes: Optional[Sequence[AveragingMode]] = None,
        gathered: Optional[Dict[Endpoint, Any]] = None,
        **kwargs,
    ):
        self._p2p = p2p
        # this peer's endpoint is derived from the p2p instance, not passed explicitly
        self.endpoint = p2p.id
        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
        if servicer is None:
            servicer = self  # this class may serve RPCs itself (e.g. in tests)
        self._servicer = servicer

        # defaults: a peer with zero fraction is a client; every non-aux peer gets weight 1
        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
        for mode, frac, weight in zip(modes, peer_fractions, weights):
            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

        self.group_id, self.ordered_group_endpoints = group_id, ordered_group_endpoints
        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

        # resolved when all-reduce finishes, fails, or is cancelled (see finalize)
        self._future = asyncio.Future()

        # senders are the peers that actually contribute tensors, i.e. everyone except auxiliaries
        self.sender_endpoints, self.sender_weights = [], []
        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
            if mode != AveragingMode.AUX:
                self.sender_endpoints.append(endpoint)
                self.sender_weights.append(weight)

        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
        # container splits local tensors into per-peer parts sized by peer_fractions
        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
        # reducer accumulates the parts that *this* peer is responsible for averaging
        self.tensor_part_reducer = TensorPartReducer(
            tuple(part.shape for part in self.parts_for_local_averaging),
            len(self.sender_endpoints),
            self.sender_weights,
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"

    def __aiter__(self):
        # iterating over the runner yields averaged tensor deltas as they become available
        return self.run()

    def __contains__(self, endpoint: Endpoint):
        # True if the given endpoint participates in this all-reduce group
        return endpoint in self.ordered_group_endpoints

    @property
    def group_size(self) -> int:
        """Number of peers in this all-reduce group."""
        return len(self.ordered_group_endpoints)

    def _get_stub(self, peer: Endpoint) -> StubBase:
        """Build an RPC stub for the given peer using the servicer's RPC signatures."""
        return self._servicer.get_stub(self._p2p, peer)

    async def run(self) -> AsyncIterator[torch.Tensor]:
        """Run all-reduce, return differences between averaged and original tensors as they are computed"""
        pending_tasks = set()
        try:
            if len(self.sender_endpoints) == 0:
                # degenerate group: nobody contributes tensors, nothing to average
                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                self.finalize()

            elif self.endpoint in self.sender_endpoints:
                # regular sender: exchange parts with every peer that averages a nonzero share
                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
                    if parts != 0:
                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))

                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
                self.finalize()

            else:  # auxiliary peer: only reduces parts for others, produces no output of its own
                await self.tensor_part_reducer.finished.wait()
                self.finalize()

        except BaseException as e:
            # propagate the failure to peers, cancel outstanding transfers, then re-raise
            self.finalize(exception=e)
            for task in pending_tasks:
                task.cancel()
            raise

    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
        """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
        if peer_endpoint == self.endpoint:
            # local "loopback": feed our own parts directly into the local reducer, no RPC
            sender_index = self.sender_endpoints.index(peer_endpoint)
            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

        else:
            loop = asyncio.get_event_loop()
            code = None
            stream = self._get_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
            async for part_index, msg in aenumerate(stream):
                if code is None:
                    code = msg.code  # only the first message's code matters
                # deserialize off the event loop; the peer streams back averaged-minus-local deltas
                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)

            # NOTE(review): if the stream yields no messages, code stays None and
            # MessageCode.Name(None) below would raise TypeError rather than AllreduceException — confirm intended
            if code != averaging_pb2.AVERAGED_PART:
                raise AllreduceException(
                    f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
                    f", allreduce failed"
                )

    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
        """Yield this peer's tensor parts destined for peer_index; only the first message carries metadata."""
        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
        first_part = await anext(parts_aiter)
        yield averaging_pb2.AveragingData(
            code=averaging_pb2.PART_FOR_AVERAGING,
            group_id=self.group_id,
            endpoint=self.endpoint.to_base58(),
            tensor_part=first_part,
        )
        # subsequent messages carry only the payload
        async for part in parts_aiter:
            yield averaging_pb2.AveragingData(tensor_part=part)

    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], _: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
        request: averaging_pb2.AveragingData = await anext(stream)
        reason_to_reject = self._check_reasons_to_reject(request)
        if reason_to_reject:
            yield reason_to_reject
            return

        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
            try:
                sender_index = self.sender_endpoints.index(Endpoint.from_base58(request.endpoint))
                # re-chain the already-consumed first message with the rest of the stream
                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                    yield msg

            except Exception as e:
                self.finalize(exception=e)
                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
        else:
            # the sender reported an error code instead of a tensor part: abort the whole run
            error_code = averaging_pb2.MessageCode.Name(request.code)
            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
        """Return an error response if the request cannot be served, or None if it is acceptable."""
        if request.group_id != self.group_id:
            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
        elif self._future.cancelled():
            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
        elif self._future.done():
            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
        # implicit None: request is acceptable

    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
        """Reduce each incoming part and stream back the (averaged - received) delta, re-using the sender's compression."""
        loop = asyncio.get_event_loop()
        # deserialize incoming parts in a background executor with bounded prefetch
        async for part_index, (tensor_part, part_compression) in aenumerate(
            amap_in_executor(
                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
                stream,
                max_prefetch=self.tensor_part_container.prefetch,
            )
        ):
            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)

            serialized_delta = await loop.run_in_executor(
                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
            )
            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
        """Notify a single peer that this all-reduce failed or was cancelled; drain its response."""
        error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
        async for _ in self._get_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
            pass

    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
        """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
        pending_tasks = set()
        if cancel or exception:
            # propagate error to peers
            if cancel or isinstance(exception, asyncio.CancelledError):
                code = averaging_pb2.CANCELLED
            else:
                code = averaging_pb2.INTERNAL_ERROR
            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
            # clients never serve rpc_aggregate_part, so only notify non-client remote peers
            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))

        if not self._future.done():
            if cancel:
                logger.debug(f"{self} - cancelled")
                self._future.cancel()
            elif exception:
                logger.debug(f"{self} - caught {exception}")
                self._future.set_exception(exception)
            else:
                logger.debug(f"{self} - finished")
                self._future.set_result(None)
            # release any consumers blocked on the container/reducer
            self.tensor_part_container.finalize()
            self.tensor_part_reducer.finalize()
            return pending_tasks
        else:
            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
            return pending_tasks