partition.py

  1. """
  2. Auxiliary data structures for AllReduceRunner
  3. """
  4. import asyncio
  5. from collections import deque
  6. from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
  7. import numpy as np
  8. import torch
  9. from hivemind.averaging.accumulators import AccumulatorFactory
  10. from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
  11. from hivemind.proto import runtime_pb2
  12. from hivemind.utils.asyncio import amap_in_executor
  13. T = TypeVar("T")
  14. DEFAULT_PART_SIZE_BYTES = 2 ** 16
  15. class TensorPartContainer:
  16. """
  17. Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
  18. The class is designed to avoid excessive memory allocation and run all heavy computation in background
  19. :param tensors: local tensors to be split and aggregated
  20. :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
  21. :param compression: optionally compress tensors with this compression algorithm before sending them to peers
  22. :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
  23. :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
  24. :param prefetch: when compressing, pre-compute this many compressed tensors in background
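
    :Example: a minimal single-peer sketch for illustration; in practice, parts intended for remote peers are
        serialized with ``iterate_input_parts_for`` and the processed parts arrive from the network:

    .. code-block:: python

        container = TensorPartContainer([torch.randn(4, 4)], peer_fractions=[1.0])
        parts = container.get_raw_input_parts(peer_index=0)  # every part is assigned to the only peer
        for i, part in enumerate(parts):
            container.register_processed_part(peer_index=0, part_index=i, part=part)
        outputs = [tensor async for tensor in container.iterate_output_tensors()]  # run inside a coroutine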
  25. """
    def __init__(
        self,
        tensors: Sequence[torch.Tensor],
        peer_fractions: Sequence[float],
        compression: CompressionBase = NoCompression(),
        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
        prefetch: int = 5,
    ):
        if tensor_infos is None:
            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
        assert len(tensor_infos) == len(tensors), "compression types do not match the number of tensors"
        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
        self.total_size = sum(tensor.numel() for tensor in tensors)
        self.prefetch = prefetch

        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
        self._outputs_consumed = False
        self.finished = asyncio.Event()
        self.num_parts_by_tensor = []

        # split tensor parts in proportion to peer_fractions
        current_length = 0
        current_peer_index = 0
        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
        pivots[-1] = self.total_size
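
        # pivots[i] is the offset (in the flattened concatenation of all tensors) at which peer i's share ends;
        # the loop below walks these boundaries while greedily splitting every tensor into parts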
        for tensor, info in zip(self.local_tensors, self.tensor_infos):
            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
            part_size_values = int(part_size_bytes / bytes_per_value)
            tensor_parts = tensor.detach().view(-1).split(part_size_values)
            self.num_parts_by_tensor.append(len(tensor_parts))
            for part_index, part in enumerate(tensor_parts):
                part_info = info.get_part(part_index, part_size_values)
                if current_length + len(part) > pivots[current_peer_index]:
                    # switch to the next peer; if a part spans the boundary between two or
                    # more peers, assign that part to the peer with the largest intersection
                    prev_peer_index = current_peer_index
                    peer_intersections = [pivots[current_peer_index] - current_length]
                    while current_length + len(part) > pivots[current_peer_index]:
                        current_peer_index += 1
                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                else:
                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                current_length += len(part)

        assert current_length == self.total_size
        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)

    @torch.no_grad()
    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
        """get non-serialized tensor parts for a peer at a given index"""
        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
        self._inputs_consumed_by_peer[peer_index] = True
        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
        self._input_parts_by_peer[peer_index].clear()
        return input_parts

    @torch.no_grad()
    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
        """iterate serialized tensor parts for a peer at a given index; runs serialization in the background"""
        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
        self._inputs_consumed_by_peer[peer_index] = True

        async def _aiterate_parts():
            for _ in range(self.num_parts_by_peer[peer_index]):
                yield self._input_parts_by_peer[peer_index].popleft()

        async for serialized_part in amap_in_executor(
            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
        ):
            yield serialized_part

    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
        """
        register the next-in-line part of results received from a given peer, for use in iterate_output_tensors
        depending on the algorithm, the processed part may be an average, a difference from the average,
        or another aggregation result
        """
        if part_index != self._outputs_registered_by_peer[peer_index]:
            raise ValueError(
                f"Could not register part #{part_index} from peer #{peer_index}, "
                f"expected part index: {self._outputs_registered_by_peer[peer_index]}"
            )
        self._output_parts_by_peer[peer_index].append(part)
        self._outputs_registered_by_peer[peer_index] += 1
        self._output_part_available[peer_index].set()

    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
        """iterate over the outputs of averaging (whether they are an average, a delta, or another aggregation result)"""
        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
        self._outputs_consumed = True
        peer_index = num_parts_processed = 0
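        # parts were distributed across peers in flat order, so outputs are consumed peer by peer
        # and concatenated back into tensors of the original shapes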
        for tensor_index in range(len(self.local_tensors)):
            tensor_parts = []
            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                    num_parts_processed = 0
                    peer_index += 1
                    continue
                if not self._output_parts_by_peer[peer_index]:
                    self._output_part_available[peer_index].clear()
                    await self._output_part_available[peer_index].wait()
                    if self.finished.is_set():
                        raise AllreduceException("All-reduce was terminated during iteration.")
                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
                num_parts_processed += 1
            tensor = torch.cat(tensor_parts)
            del tensor_parts
            yield tensor.reshape(self.local_tensors[tensor_index].shape)

    def __del__(self):
        self.finalize()

    def finalize(self):
        """terminate all iterators, delete intermediate data"""
        if not self.finished.is_set():
            for peer_index in range(self.group_size):
                self._inputs_consumed_by_peer[peer_index] = True
                self._input_parts_by_peer[peer_index].clear()
                self._output_parts_by_peer[peer_index].clear()
                self._output_part_available[peer_index].set()
            self._outputs_consumed = True
            self.finished.set()


class TensorPartReducer:
    """
    Auxiliary data structure responsible for running asynchronous all-reduce

    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
    :param num_senders: total number of peers in a given all-reduce group that will send gradients
    :param weights: relative importance of each sender, used for the weighted average (default = equal weights)
    :note: even if the local peer is not sending data, local parts will be used for shape information
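
    :Example: an illustrative sketch; ``MeanAccumulator`` is a stand-in for any factory from
        ``hivemind.averaging.accumulators`` that produces accumulators with the ``accumulate_part`` / ``reduce``
        interface used below:

    .. code-block:: python

        # MeanAccumulator is a placeholder name; any compatible AccumulatorFactory will do
        reducer = TensorPartReducer([torch.Size([16])], num_senders=1, weights=None,
                                    accumulator_factory=MeanAccumulator)
        # (inside a coroutine) every sender calls accumulate_part; the future resolves once all senders contributed
        averaged = await reducer.accumulate_part(sender_index=0, part_index=0, tensor_part=torch.ones(16))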
  152. """
    def __init__(
        self,
        part_shapes: Sequence[torch.Size],
        num_senders: int,
        *,
        weights: Optional[Sequence[float]] = None,
        accumulator_factory: AccumulatorFactory,
    ):
        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
        self.weights = tuple(weights or (1 for _ in range(num_senders)))
        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
        assert all(isinstance(weight, (int, float)) for weight in self.weights)
        self.current_part_index = -1  # index in local_parts of the part that should be loaded next
        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
        self.current_part_future = asyncio.Future()
        self.accumulator_factory = accumulator_factory
        self.accumulator = None
        self.finished = asyncio.Event()
        self.reset_accumulators()

    def reset_accumulators(self):
        """(re)create the accumulator for the next part in line"""
        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
        if self.current_part_index >= self.num_parts - 1:
            self.finalize()
            return
        self.current_part_index += 1
        self.current_part_accumulated_from = 0
        self.current_part_future = asyncio.Future()
        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)

    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
        """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
        assert 0 <= sender_index < self.num_senders, "invalid sender index"
        assert 0 <= part_index < self.num_parts, "invalid part index"

        while part_index > self.current_part_index:
            # wait for previous parts to finish processing ...
            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
            if self.finished.is_set():
                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")

        assert part_index == self.current_part_index
        current_part_future = self.current_part_future

        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
        self.current_part_accumulated_from += 1

        assert self.current_part_accumulated_from <= self.num_senders
        if self.current_part_accumulated_from == self.num_senders:
            current_part_future.set_result(self.accumulator.reduce())
            self.reset_accumulators()
        return await current_part_future

    def finalize(self):
        if not self.finished.is_set():
            if hasattr(self, "current_part_future"):
                self.current_part_future.cancel()
                self.accumulator = None
            self.finished.set()

    def __del__(self):
        self.finalize()


class AllreduceException(Exception):
    """A special exception that is raised when allreduce can't continue normally (e.g. disconnected or protocol error)"""