  1. """
  2. Auxiliary data structures for AllReduceRunner
  3. """
  4. import asyncio
  5. from collections import deque
  6. from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
  7. import numpy as np
  8. import torch
  9. from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
  10. from hivemind.proto import runtime_pb2
  11. from hivemind.utils import amap_in_executor, as_aiter, get_logger
  12. T = TypeVar("T")
  13. DEFAULT_PART_SIZE_BYTES = 2**19
  14. logger = get_logger(__name__)


class TensorPartContainer:
    """
    Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
    The class is designed to avoid excessive memory allocation and to run all heavy computation in the background.

    :param tensors: local tensors to be split and aggregated
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
    :param compression: optionally compress tensors with this compression algorithm before sending them to peers
    :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
    :param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
    :param prefetch: when compressing, pre-compute this many compressed tensors in background
    """

    def __init__(
        self,
        tensors: Sequence[torch.Tensor],
        peer_fractions: Sequence[float],
        compression: CompressionBase = NoCompression(),
        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
        return_deltas: bool = True,
        prefetch: int = 1,
    ):
        if tensor_infos is None:
            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
        assert len(tensor_infos) == len(tensors), "compression types do not match the number of tensors"
        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
        self.total_size = sum(tensor.numel() for tensor in tensors)
        self.failed_size = 0
        self.return_deltas = return_deltas
        self.prefetch = prefetch

        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
        self._outputs_consumed = False
        self.finished = asyncio.Event()
        self.num_parts_by_tensor = []

        # split tensor parts in proportion to target_size_by_peer
        current_length = 0
        current_peer_index = 0
        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
        pivots[-1] = self.total_size
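
        # Worked example (illustrative numbers, not from the library): with peer_fractions=(1, 2, 1)
        # and total_size=400, pivots == [100, 300, 400], so peer 0 is responsible for elements
        # [0, 100), peer 1 for [100, 300) and peer 2 for [300, 400); a part that straddles a pivot
        # is assigned below to the peer with the largest overlap.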
        for tensor, info in zip(self.local_tensors, self.tensor_infos):
            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
            part_size_values = int(part_size_bytes / bytes_per_value)
            tensor_parts = tensor.detach().view(-1).split(part_size_values)
            self.num_parts_by_tensor.append(len(tensor_parts))
            for part_index, part in enumerate(tensor_parts):
                part_info = info.get_part(part_index, part_size_values)
                if current_length + len(part) > pivots[current_peer_index]:
                    # switch to next peer; if a part lands between parts of two or
                    # more peers, assign that part to the peer with highest intersection
                    prev_peer_index = current_peer_index
                    peer_intersections = [pivots[current_peer_index] - current_length]
                    while current_length + len(part) > pivots[current_peer_index]:
                        current_peer_index += 1
                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                else:
                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                current_length += len(part)

        assert current_length == self.total_size
        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)

    @torch.no_grad()
    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
        """get non-serialized tensor parts for a peer at a given index"""
        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
        self._inputs_consumed_by_peer[peer_index] = True
        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
        return input_parts

    @torch.no_grad()
    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
        """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
        self._inputs_consumed_by_peer[peer_index] = True
        parts_aiter = as_aiter(*self._input_parts_by_peer[peer_index])
        async for serialized_part in amap_in_executor(
            lambda x_and_info: self.compression.compress(*x_and_info), parts_aiter, max_prefetch=self.prefetch
        ):
            yield serialized_part

    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
        """
        register the next-in-line part of results received from a given peer for use in iterate_output_tensors
        depending on the algorithm, the processed part is an average, a difference from the average, or another aggregation
        """
        if part_index != self._outputs_registered_by_peer[peer_index]:
            raise ValueError(
                f"Could not register part #{part_index} from peer #{peer_index}, "
                f"expected part index: {self._outputs_registered_by_peer[peer_index]}"
            )
        self._output_parts_by_peer[peer_index].append(part)
        self._outputs_registered_by_peer[peer_index] += 1
        self._output_part_available[peer_index].set()

    def register_failed_reducer(self, peer_index: int):
        """
        a given peer failed to aggregate a certain part: use our local part instead and keep track of failed parts
        """
        for part_index in range(self._outputs_registered_by_peer[peer_index], self.num_parts_by_peer[peer_index]):
            part_and_info = self._input_parts_by_peer[peer_index][part_index]
            part_result_or_delta = torch.zeros_like(part_and_info[0]) if self.return_deltas else part_and_info[0]
            self.register_processed_part(peer_index, part_index, part_result_or_delta)
            self.failed_size += part_result_or_delta.numel()

    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
        """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
        self._outputs_consumed = True
        peer_index = num_parts_processed = 0
        for tensor_index in range(len(self.local_tensors)):
            tensor_parts = []
            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                    num_parts_processed = 0
                    peer_index += 1
                    continue
                if not self._output_parts_by_peer[peer_index]:
                    self._output_part_available[peer_index].clear()
                    await self._output_part_available[peer_index].wait()
                    if self.finished.is_set():
                        raise AllreduceException("All-reduce was terminated during iteration")
                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
                num_parts_processed += 1
            tensor = torch.cat(tensor_parts)
            del tensor_parts
            yield tensor.reshape(self.local_tensors[tensor_index].shape)

    def __del__(self):
        self.finalize()

    def finalize(self):
        """terminate all iterators, delete intermediate data"""
        if not self.finished.is_set():
            for peer_index in range(self.group_size):
                self._inputs_consumed_by_peer[peer_index] = True
                self._output_part_available[peer_index].set()
                self._input_parts_by_peer[peer_index].clear()
                self._output_parts_by_peer[peer_index].clear()
            if self.failed_size != 0:
                logger.warning(f"Averaging: received {(1. - self.failed_size / self.total_size) * 100:.1f}% results")
            self._outputs_consumed = True
            self.finished.set()
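

# A minimal usage sketch for TensorPartContainer (hypothetical driver code; in practice AllReduceRunner
# orchestrates these calls over the network):
#
#     container = TensorPartContainer(tensors=my_tensors, peer_fractions=(0.5, 0.5))
#     # for each peer: stream its serialized parts, then register the reduced parts as they come back
#     #     async for serialized_part in container.iterate_input_parts_for(peer_index): ...
#     #     container.register_processed_part(peer_index, part_index, reduced_part)
#     # once all parts are registered, reassemble the full outputs (deltas if return_deltas=True):
#     #     async for output_tensor in container.iterate_output_tensors(): ...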


class TensorPartReducer:
    """
    Auxiliary data structure responsible for running asynchronous all-reduce

    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
    :param num_senders: total number of peers in a given all-reduce group that will send gradients
    :note: even if the local peer is not sending data, its local parts will be used for shape information
    """

    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
        self.current_part_index = -1  # index in local_parts of the part that should be loaded next
        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
        self.accumulator = None  # this will contain the sum of current tensor part from group peers
        self.denominator = 0.0  # total weight accumulated from all peers for current part
        self.current_part_future = asyncio.Future()
        self.finished = asyncio.Event()

        self.num_parts_received = [0 for _ in range(self.num_senders)]
        self.sender_failed_after = [float("inf") for _ in range(self.num_senders)]
        self.num_current_senders = self.num_senders

        self.reset_accumulators()

    def reset_accumulators(self):
        """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
        assert self.current_part_accumulated_from == self.num_current_senders or self.current_part_index == -1
        if self.current_part_index >= self.num_parts - 1:
            self.finalize()
            return

        self.current_part_index += 1
        self.current_part_accumulated_from = 0
        self.current_part_future = asyncio.Future()
        self.num_current_senders = sum(
            self.current_part_index < failed_index for failed_index in self.sender_failed_after
        )
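        # e.g. (hypothetical): if sender 1 failed after submitting 3 parts, sender_failed_after[1] == 3,
        # so it is still counted for parts 0..2 but excluded from num_current_senders for parts 3 and onward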
        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
        self.denominator = 0.0

    async def accumulate_part(
        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
    ) -> torch.Tensor:
        """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
        assert 0 <= sender_index < self.num_senders, "invalid sender index"
        assert 0 <= part_index < self.num_parts, "invalid part index"

        self.num_parts_received[sender_index] += 1

        while part_index > self.current_part_index:
            # wait for previous parts to finish processing ...
            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
            if self.finished.is_set():
                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")

        if self.sender_failed_after[sender_index] != float("inf"):
            raise BannedException(f"sender {sender_index} was banned in background")

        assert part_index == self.current_part_index
        current_part_future = self.current_part_future

        if part_index < self.sender_failed_after[sender_index]:
            self.accumulator.add_(tensor_part, alpha=weight)
            self.current_part_accumulated_from += 1
            self.denominator += weight
            self.check_current_part_finished()
        return await current_part_future

    def on_sender_failed(self, sender_index: int):
        """Exclude that sender's data for averaging any parts that it did not submit yet."""
        self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
        if self.finished.is_set():
            return
        if self.current_part_index == self.num_parts_received[sender_index]:
            self.num_current_senders -= 1
            self.check_current_part_finished()

    def check_current_part_finished(self):
        assert self.current_part_accumulated_from <= self.num_current_senders
        if self.current_part_accumulated_from == self.num_current_senders:
            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
            self.reset_accumulators()

    def finalize(self):
        if not self.finished.is_set():
            if hasattr(self, "current_part_future"):
                self.current_part_future.cancel()
                del self.accumulator
            self.finished.set()

            if self.num_parts != 0 and self.num_senders != 0:
                parts_expected = self.num_parts * self.num_senders
                parts_received = sum(self.num_parts_received)
                if parts_expected != parts_received:
                    logger.warning(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")

    def __del__(self):
        self.finalize()
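

# A minimal usage sketch for TensorPartReducer (hypothetical shapes and weights; in practice each sender's
# parts arrive over the network):
#
#     reducer = TensorPartReducer(part_shapes=[torch.Size([4])], num_senders=2)
#     # each sender submits its copy of part 0; both calls resolve to the same weighted average,
#     # i.e. (a * 0.5 + b * 1.5) / (0.5 + 1.5):
#     #     averaged = await reducer.accumulate_part(sender_index=0, part_index=0, tensor_part=a, weight=0.5)
#     #     averaged = await reducer.accumulate_part(sender_index=1, part_index=0, tensor_part=b, weight=1.5)
#     reducer.finalize()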


class AllreduceException(Exception):
    """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""


class BannedException(AllreduceException):
    """An exception that indicates that a given sender was banned and will no longer be aggregated"""