partition.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. """
  2. Auxiliary data structures for AllReduceRunner
  3. """
  4. import asyncio
  5. from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator, Type, Deque, List
  6. from collections import deque
  7. import torch
  8. import numpy as np
  9. from hivemind.proto.runtime_pb2 import CompressionType, Tensor
  10. from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
  11. from hivemind.utils.asyncio import amap_in_executor
  12. T = TypeVar("T")
  13. DEFAULT_PART_SIZE_BYTES = 2 ** 20
  14. class TensorPartContainer:
  15. """
  16. Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
  17. The class is designed to avoid excessive memory allocation and run all heavy computation in background
  18. :param tensors: local tensors to be split and aggregated
  19. :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
  20. :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
  21. :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
  22. :param prefetch: when compressing, pre-compute this many compressed tensors in background
  23. """
  24. def __init__(
  25. self,
  26. tensors: Sequence[torch.Tensor],
  27. peer_fractions: Sequence[float],
  28. compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
  29. part_size_bytes: int = 2 ** 20,
  30. prefetch: int = 1,
  31. ):
  32. if not isinstance(compression_type, Sequence):
  33. compression_type = [compression_type] * len(tensors)
  34. assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
  35. self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
  36. self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
  37. self.total_size = sum(tensor.numel() for tensor in tensors)
  38. self._input_parts_by_peer: List[Deque[Tuple[torch.Tensor, Type[CompressionType]]]] = [
  39. deque() for _ in range(self.group_size)
  40. ]
  41. self._output_parts_by_peer: List[Deque[torch.Tensor]] = [deque() for _ in range(self.group_size)]
  42. self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
  43. self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
  44. self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
  45. self._outputs_consumed = False
  46. self.finished = asyncio.Event()
  47. self.num_parts_by_tensor = []
  48. # split tensor parts in proportion to target_size_by_peer
  49. current_length = 0
  50. current_peer_index = 0
  51. pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
  52. pivots[-1] = self.total_size
  53. for tensor, tensor_compression in zip(self.local_tensors, compression_type):
  54. part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
  55. tensor_parts = tensor.detach().view(-1).split(part_size_values)
  56. self.num_parts_by_tensor.append(len(tensor_parts))
  57. for part in tensor_parts:
  58. if current_length + len(part) > pivots[current_peer_index]:
  59. # switch to next peer; if a part lands between parts of two or
  60. # more peers, assign that part to the peer with highest intersection
  61. prev_peer_index = current_peer_index
  62. peer_intersections = [pivots[current_peer_index] - current_length]
  63. while current_length + len(part) > pivots[current_peer_index]:
  64. current_peer_index += 1
  65. current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
  66. peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
  67. assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
  68. self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
  69. else:
  70. self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
  71. current_length += len(part)
  72. assert current_length == self.total_size
  73. self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)
  74. @torch.no_grad()
  75. def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
  76. """get non-serialized tensor parts for a peer at a given index"""
  77. assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
  78. self._inputs_consumed_by_peer[peer_index] = True
  79. input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
  80. self._input_parts_by_peer[peer_index].clear()
  81. return input_parts
  82. @torch.no_grad()
  83. async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
  84. """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
  85. assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
  86. self._inputs_consumed_by_peer[peer_index] = True
  87. async def _aiterate_parts():
  88. for _ in range(self.num_parts_by_peer[peer_index]):
  89. yield self._input_parts_by_peer[peer_index].popleft()
  90. async for serialized_part in amap_in_executor(
  91. lambda x_and_compr: serialize_torch_tensor(*x_and_compr), _aiterate_parts(), max_prefetch=self.prefetch
  92. ):
  93. yield serialized_part
  94. def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
  95. """
  96. register next-in-line part of results received from a given peer for use in iterate_output_tensors
  97. depending on the algorithm, processed part is an average, difference from average or another aggregation
  98. """
  99. if part_index != self._outputs_registered_by_peer[peer_index]:
  100. raise ValueError(
  101. f"Could not register part #{part_index} from peer #{peer_index}, "
  102. f" expected part index: {self._outputs_registered_by_peer[peer_index]}"
  103. )
  104. self._output_parts_by_peer[peer_index].append(part)
  105. self._outputs_registered_by_peer[peer_index] += 1
  106. self._output_part_available[peer_index].set()
  107. async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
  108. """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
  109. assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
  110. self._outputs_consumed = True
  111. peer_index = num_parts_processed = 0
  112. for tensor_index in range(len(self.local_tensors)):
  113. tensor_parts: List[torch.Tensor] = []
  114. while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
  115. if num_parts_processed >= self.num_parts_by_peer[peer_index]:
  116. num_parts_processed = 0
  117. peer_index += 1
  118. continue
  119. if not self._output_parts_by_peer[peer_index]:
  120. self._output_part_available[peer_index].clear()
  121. await self._output_part_available[peer_index].wait()
  122. if self.finished.is_set():
  123. raise AllreduceException("All-reduce was terminated during iteration.")
  124. tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
  125. num_parts_processed += 1
  126. tensor = torch.cat(tensor_parts)
  127. del tensor_parts
  128. yield tensor.reshape(self.local_tensors[tensor_index].shape)
  129. def __del__(self):
  130. self.finalize()
  131. def finalize(self):
  132. """terminate all iterators, delete intermediate data"""
  133. if not self.finished.is_set():
  134. for peer_index in range(self.group_size):
  135. self._inputs_consumed_by_peer[peer_index] = True
  136. self._input_parts_by_peer[peer_index].clear()
  137. self._output_parts_by_peer[peer_index].clear()
  138. self._output_part_available[peer_index].set()
  139. self._outputs_consumed = True
  140. self.finished.set()
  141. class TensorPartReducer:
  142. """
  143. Auxiliary data structure responsible for running asynchronous all-reduce
  144. :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
  145. :param num_senders: total number of peers in a given all-reduce group that will send gradients
  146. :param weights: relative importance of each sender, used for weighted average (default = equal weights)
  147. :note: even if local peer is not sending data, local parts will be used for shape information
  148. """
  149. def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
  150. self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
  151. self.weights = tuple(weights or (1 for _ in range(num_senders)))
  152. assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
  153. assert all(isinstance(weight, (int, float)) for weight in self.weights)
  154. self.current_part_index = -1 # index in local_parts of the part that should be loaded next
  155. self.current_part_accumulated_from = 0 # number of peers from which the current part was accumulated
  156. self.accumulator: Optional[torch.Tensor] = None # contains the sum of current tensor part from group peers
  157. self.denominator = 0.0 # total weight accumulated from all peers for current part
  158. self.current_part_future = asyncio.Future()
  159. self.finished = asyncio.Event()
  160. self.reset_accumulators()
  161. def reset_accumulators(self):
  162. """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
  163. assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
  164. if self.current_part_index >= self.num_parts - 1:
  165. self.finalize()
  166. return
  167. self.current_part_index += 1
  168. self.current_part_accumulated_from = 0
  169. self.current_part_future = asyncio.Future()
  170. self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
  171. self.denominator = 0.0
  172. async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
  173. """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
  174. assert 0 <= sender_index < self.num_senders, "invalid sender index"
  175. assert 0 <= part_index < self.num_parts, "invalid part index"
  176. while part_index > self.current_part_index:
  177. # wait for previous parts to finish processing ...
  178. await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
  179. if self.finished.is_set():
  180. raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
  181. assert part_index == self.current_part_index
  182. current_part_future = self.current_part_future
  183. self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
  184. self.denominator += self.weights[sender_index]
  185. self.current_part_accumulated_from += 1
  186. assert self.current_part_accumulated_from <= self.num_senders
  187. if self.current_part_accumulated_from == self.num_senders:
  188. current_part_future.set_result(self.accumulator.div_(self.denominator))
  189. self.reset_accumulators()
  190. return await current_part_future
  191. def finalize(self):
  192. if not self.finished.is_set():
  193. if hasattr(self, "current_part_future"):
  194. self.current_part_future.cancel()
  195. del self.accumulator
  196. self.finished.set()
  197. def __del__(self):
  198. self.finalize()
  199. class AllreduceException(Exception):
  200. """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""