allreduce.py

import asyncio
from enum import Enum
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type

import torch

from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
from hivemind.proto import averaging_pb2
from hivemind.utils import get_logger
from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter

# flavour types
GroupID = bytes

logger = get_logger(__name__)


class AveragingMode(Enum):
    NODE = 0
    CLIENT = 1
    AUX = 2


class AllReduceRunner(ServicerBase):
    """
    An internal class that runs butterfly AllReduce in a predefined group of averagers.

    This class inherits hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes without
    creating a full DecentralizedAverager.

    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
    :note: your peer_id (taken from the p2p instance) must be included in ordered_peer_ids
    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
    :param servicer_type: a hivemind.p2p.ServicerBase subclass whose RPC signatures are used
      when requesting other peers. Typically, it is DecentralizedAverager, its derivative,
      or AllReduceRunner itself (for testing purposes).
    :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
    :param group_id: unique identifier of this specific all-reduce run
    :param tensors: local tensors that should be averaged with groupmates
    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
      (the actual number of values per peer will be nearly proportional, but there are no exact guarantees)
    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
    :param gathered: additional user-defined data collected from this group
    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
    """
    def __init__(
        self,
        *,
        p2p: P2P,
        servicer_type: Type[ServicerBase],
        prefix: Optional[str],
        group_id: GroupID,
        tensors: Sequence[torch.Tensor],
        ordered_peer_ids: Sequence[PeerID],
        peer_fractions: Tuple[float, ...],
        weights: Optional[Sequence[float]] = None,
        modes: Optional[Sequence[AveragingMode]] = None,
        gathered: Optional[Dict[PeerID, Any]] = None,
        **kwargs,
    ):
        self._p2p = p2p
        self.peer_id = p2p.peer_id
        assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"

        if not issubclass(servicer_type, ServicerBase):
            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
        self._servicer_type = servicer_type
        self._prefix = prefix

        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
        for mode, frac, weight in zip(modes, peer_fractions, weights):
            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

        self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
        self._future = asyncio.Future()

        self.sender_peer_ids, self.sender_weights = [], []
        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
            if mode != AveragingMode.AUX:
                self.sender_peer_ids.append(peer_id)
                self.sender_weights.append(weight)

        peer_id_index = self.ordered_peer_ids.index(self.peer_id)
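        # TensorPartContainer splits this peer's local tensors into per-peer parts (per peer_fractions);
        # TensorPartReducer accumulates the parts that all senders contribute for the slice this peer owns.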
        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
        self.tensor_part_reducer = TensorPartReducer(
            tuple(part.shape for part in self.parts_for_local_averaging),
            len(self.sender_peer_ids),
            self.sender_weights,
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})"

    def __aiter__(self):
        return self.run()

    def __contains__(self, peer_id: PeerID):
        return peer_id in self.ordered_peer_ids

    @property
    def group_size(self):
        return len(self.ordered_peer_ids)

    def _get_peer_stub(self, peer: PeerID) -> StubBase:
        return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)

    async def run(self) -> AsyncIterator[torch.Tensor]:
        """Run all-reduce, return differences between averaged and original tensors as they are computed"""
        pending_tasks = set()
        try:
            if len(self.sender_peer_ids) == 0:
                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                self.finalize()

            elif self.peer_id in self.sender_peer_ids:
                for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer):
                    if parts != 0:
                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id)))

                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor

                self.finalize()

            else:  # auxiliary peer
                await self.tensor_part_reducer.finished.wait()
                self.finalize()

        except BaseException as e:
            self.finalize(exception=e)
            for task in pending_tasks:
                task.cancel()
            raise
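
    # A minimal, hypothetical usage sketch (group formation and P2P setup are omitted; the names
    # `runner` and `local_tensors` are assumptions, not part of this module):
    #
    #   async def apply_allreduce(runner: "AllReduceRunner", local_tensors: Sequence[torch.Tensor]):
    #       index = 0
    #       async for delta in runner:  # __aiter__ delegates to run(), yielding one delta per tensor
    #           local_tensors[index] += delta  # delta = averaged - local, so adding it recovers the average
    #           index += 1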

    async def _communicate_with_peer(self, peer_id: PeerID):
        """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
        peer_index = self.ordered_peer_ids.index(peer_id)
        if peer_id == self.peer_id:
            sender_index = self.sender_peer_ids.index(peer_id)
            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

        else:
            code = None
            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
            async for part_index, (averaged_part_delta, msg) in aenumerate(
                amap_in_executor(
                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
                    stream,
                    max_prefetch=self.tensor_part_container.prefetch,
                )
            ):
                if code is None:
                    code = msg.code
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)

            if code != averaging_pb2.AVERAGED_PART:
                raise AllreduceException(
                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
                    f", allreduce failed"
                )

    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
        first_part = await anext(parts_aiter)
        yield averaging_pb2.AveragingData(
            code=averaging_pb2.PART_FOR_AVERAGING,
            group_id=self.group_id,
            tensor_part=first_part,
        )
        async for part in parts_aiter:
            yield averaging_pb2.AveragingData(tensor_part=part)
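
    # Note: only the first message of the stream carries `code` and `group_id`; subsequent messages
    # carry only `tensor_part`, keeping per-part overhead small.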

    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """A peer sends us a part of their tensors; we average it with other peers and return the difference"""
        request: averaging_pb2.AveragingData = await anext(stream)
        reason_to_reject = self._check_reasons_to_reject(request)
        if reason_to_reject:
            yield reason_to_reject
            return

        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
            try:
                sender_index = self.sender_peer_ids.index(context.remote_id)
                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
                    yield msg

            except Exception as e:
                self.finalize(exception=e)
                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
        else:
            error_code = averaging_pb2.MessageCode.Name(request.code)
            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
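        # Returns a rejection message to send back to the peer, or None if the request can be processed.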
        if request.group_id != self.group_id:
            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
        elif self._future.cancelled():
            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
        elif self._future.done():
            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
        loop = asyncio.get_event_loop()
        async for part_index, (tensor_part, part_compression) in aenumerate(
            amap_in_executor(
                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
                stream,
                max_prefetch=self.tensor_part_container.prefetch,
            )
        ):
            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)

            serialized_delta = await loop.run_in_executor(
                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
            )
            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
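
    # Serialization of each outgoing delta is offloaded to the default executor and re-uses the
    # compression scheme of the corresponding incoming part.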

    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
        # Coroutines are lazy, so we take the first item to start the coroutine's execution
        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))

    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
        """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
        pending_tasks = set()
        if cancel or exception:
            # propagate error to peers
            if cancel or isinstance(exception, asyncio.CancelledError):
                code = averaging_pb2.CANCELLED
            else:
                code = averaging_pb2.INTERNAL_ERROR
            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))

        if not self._future.done():
            if cancel:
                logger.debug(f"{self} - cancelled")
                self._future.cancel()
            elif exception:
                logger.debug(f"{self} - caught {exception}")
                self._future.set_exception(exception)
            else:
                logger.debug(f"{self} - finished")
                self._future.set_result(None)
            self.tensor_part_container.finalize()
            self.tensor_part_reducer.finalize()
            return pending_tasks
        else:
            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
            return pending_tasks