allreduce.py

import asyncio
from enum import Enum
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type

import torch

from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
from hivemind.proto import averaging_pb2
from hivemind.utils import get_logger
from hivemind.utils.asyncio import (
    achain,
    aenumerate,
    afirst,
    amap_in_executor,
    anext,
    as_aiter,
    attach_event_on_finished,
)

# flavour types
GroupID = bytes

logger = get_logger(__name__)


class AveragingMode(Enum):
    NODE = 0
    CLIENT = 1
    AUX = 2
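
# How the three modes are used by AllReduceRunner below:
# - NODE: a regular peer that sends its local tensors and also reduces its assigned fraction of every tensor
# - CLIENT: sends its tensors but reduces nothing (its peer_fraction must be zero, see the asserts in __init__)
# - AUX: reduces parts on behalf of others but contributes no tensors of its own (its default weight is zero)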


class AllReduceRunner(ServicerBase):
    """
    An internal class that runs butterfly AllReduce in a predefined group of averagers.

    This class inherits hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes without
    creating a full DecentralizedAverager.

    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
    :param servicer_type: a hivemind.p2p.ServicerBase subclass whose RPC signatures are used
      when requesting other peers. Typically, it is DecentralizedAverager, its derivative,
      or AllReduceRunner itself (for testing purposes).
    :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
    :param group_id: unique identifier of this specific all-reduce run
    :param tensors: local tensors that should be averaged with groupmates
    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
    :param peer_id: your peer_id, must be included in ordered_peer_ids
    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
    :param gathered: additional user-defined data collected from this group
    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
      non-averaging peers receive results only after they finish sending, which helps them avoid
      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
    """

    def __init__(
        self,
        *,
        p2p: P2P,
        servicer_type: Type[ServicerBase],
        prefix: Optional[str],
        group_id: GroupID,
        tensors: Sequence[torch.Tensor],
        weight: Optional[float] = None,
        ordered_peer_ids: Sequence[PeerID],
        peer_fractions: Tuple[float, ...],
        modes: Optional[Sequence[AveragingMode]] = None,
        gathered: Optional[Dict[PeerID, Any]] = None,
        **kwargs,
    ):
        self._p2p = p2p
        self.peer_id = p2p.peer_id
        assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"

        if not issubclass(servicer_type, ServicerBase):
            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
        self._servicer_type = servicer_type
        self._prefix = prefix

        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
        for mode, frac in zip(modes, peer_fractions):
            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"

        self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

        if weight is None:
            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
        self.weight = weight

        self._future = asyncio.Future()
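
        # senders are the peers whose tensors enter the average, i.e. every peer except auxiliaries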
        self.sender_peer_ids = []
        for peer_id, mode in zip(self.ordered_peer_ids, modes):
            if mode != AveragingMode.AUX:
                self.sender_peer_ids.append(peer_id)

        peer_id_index = self.ordered_peer_ids.index(self.peer_id)
        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
        self.tensor_part_reducer = TensorPartReducer(
            tuple(part.shape for part in self.parts_for_local_averaging),
            len(self.sender_peer_ids),
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})"

    def __aiter__(self):
        return self.run()

    def __contains__(self, peer_id: PeerID):
        return peer_id in self.ordered_peer_ids

    @property
    def group_size(self):
        return len(self.ordered_peer_ids)

    def _get_peer_stub(self, peer: PeerID) -> StubBase:
        return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
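
    # client-mode peers (zero all-reduce fraction) receive their results only after they finish sending,
    # which avoids bidirectional streaming over asymmetric connections (see the class-level note)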
    def should_delay_results(self, peer_id: PeerID) -> bool:
        return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0

    async def run(self) -> AsyncIterator[torch.Tensor]:
        """Run all-reduce, return differences between averaged and original tensors as they are computed"""
        pending_tasks = set()
        try:
            if len(self.sender_peer_ids) == 0:
                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                self.finalize()

            elif self.peer_id in self.sender_peer_ids:
                for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer):
                    if parts != 0:
                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id)))

                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor

                self.finalize()

            else:  # auxiliary peer
                await self.tensor_part_reducer.finished.wait()
                self.finalize()

        except BaseException as e:
            self.finalize(exception=e)
            for task in pending_tasks:
                task.cancel()
            raise

    async def _communicate_with_peer(self, peer_id: PeerID):
        """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
        peer_index = self.ordered_peer_ids.index(peer_id)
        if peer_id == self.peer_id:
            sender_index = self.sender_peer_ids.index(peer_id)
            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                averaged_part = await self.tensor_part_reducer.accumulate_part(
                    sender_index, part_index, tensor_part, weight=self.weight
                )
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

        else:
            code = None
            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
            async for part_index, (averaged_part_delta, msg) in aenumerate(
                amap_in_executor(
                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
                    stream,
                    max_prefetch=self.tensor_part_container.prefetch,
                )
            ):
                if code is None:
                    code = msg.code
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)

            if code != averaging_pb2.AVERAGED_PART:
                raise AllreduceException(
                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
                    f", allreduce failed"
                )
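
    # only the first yielded message carries the message code and group_id;
    # every message carries one tensor part and this peer's weight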
    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
        first_part = await anext(parts_aiter)
        yield averaging_pb2.AveragingData(
            code=averaging_pb2.PART_FOR_AVERAGING,
            group_id=self.group_id,
            tensor_part=first_part,
            weight=self.weight,
        )
        async for part in parts_aiter:
            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)

    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
        request: averaging_pb2.AveragingData = await anext(stream)
        reason_to_reject = self._check_reasons_to_reject(request)
        if reason_to_reject:
            yield reason_to_reject
            return

        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
            try:
                sender_index = self.sender_peer_ids.index(context.remote_id)

                if not self.should_delay_results(context.remote_id):
                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
                        yield msg

                else:
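                    # client-mode sender: buffer the outgoing deltas in a queue and start yielding them
                    # only once the peer has finished sending all of its parts (see should_delay_results)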
                    done_receiving = asyncio.Event()
                    delayed_results = asyncio.Queue()

                    async def _accumulate_parts():
                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
                            delayed_results.put_nowait(msg)
                        delayed_results.put_nowait(None)

                    accumulate_task = asyncio.create_task(_accumulate_parts())

                    await done_receiving.wait()
                    while True:
                        next_result = await delayed_results.get()
                        if next_result is None:
                            break
                        yield next_result
                    await accumulate_task

            except Exception as e:
                self.finalize(exception=e)
                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
        else:
            error_code = averaging_pb2.MessageCode.Name(request.code)
            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
            self.finalize(exception=AllreduceException(f"Peer {context.remote_id} sent {error_code}"))
            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
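
    # returns a rejection message if the request cannot be served (wrong group id, or this allreduce
    # has already been cancelled / finished); returns None if the request is acceptable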
    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
        if request.group_id != self.group_id:
            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
        elif self._future.cancelled():
            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
        elif self._future.done():
            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
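
    # deserialize each incoming part, add it to the shared reducer, then stream back the serialized
    # difference between the averaged part and the sender's original part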
    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
        loop = asyncio.get_event_loop()
        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
            amap_in_executor(
                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                stream,
                max_prefetch=self.tensor_part_container.prefetch,
            )
        ):
            averaged_part = await self.tensor_part_reducer.accumulate_part(
                sender_index, part_index, tensor_part, weight=weight
            )

            serialized_delta = await loop.run_in_executor(
                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
            )
            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
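
    # best-effort notification: failure to deliver the error code is logged and otherwise ignored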
    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
        try:
            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
        except Exception as e:
            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")

    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
        """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
        pending_tasks = set()
        if cancel or exception:
            # propagate error to peers
            if cancel or isinstance(exception, asyncio.CancelledError):
                code = averaging_pb2.CANCELLED
            else:
                code = averaging_pb2.INTERNAL_ERROR
            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))

        if not self._future.done():
            if cancel:
                logger.debug(f"{self} - cancelled")
                self._future.cancel()
            elif exception:
                logger.debug(f"{self} - caught {exception}")
                self._future.set_exception(exception)
            else:
                logger.debug(f"{self} - finished")
                self._future.set_result(None)
            self.tensor_part_container.finalize()
            self.tensor_part_reducer.finalize()
            return pending_tasks
        else:
            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
            return pending_tasks
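

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical way to
# drive AllReduceRunner directly between a few peers, e.g. in a test. The names
# `example_allreduce`, `my_tensors` and the literal group id are illustrative
# assumptions, and the exact setup of P2P daemons and RPC handlers may differ in
# your environment; every peer in `peers` is expected to run the same code
# concurrently with its own tensors.
#
# async def example_allreduce(p2p: P2P, peers: Sequence[PeerID], my_tensors: Sequence[torch.Tensor]):
#     runner = AllReduceRunner(
#         p2p=p2p,
#         servicer_type=AllReduceRunner,
#         prefix=None,
#         group_id=b"example-group",
#         tensors=my_tensors,
#         ordered_peer_ids=peers,
#         peer_fractions=tuple(1 / len(peers) for _ in peers),
#     )
#     await runner.add_p2p_handlers(p2p)  # register rpc_aggregate_part so groupmates can reach this peer
#     async for tensor_index, delta in aenumerate(runner):
#         my_tensors[tensor_index].add_(delta)  # apply the (averaged - local) difference in place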