# test_allreduce_fault_tolerance.py

from __future__ import annotations

import asyncio
from enum import Enum, auto

import pytest
import torch

import hivemind
from hivemind.averaging.averager import *
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import MatchmakingException
from hivemind.proto import averaging_pb2
from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
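
# This module stress-tests fault tolerance in hivemind.DecentralizedAverager:
# FaultyAverager and FaultyAllReduceRunner below inject failures at different stages
# of an all-reduce round, and test_fault_tolerance checks that the remaining healthy
# peers still finish the round and end up close to the expected average.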


class Fault(Enum):
    NONE = auto()
    FAIL_BEFORE = auto()
    FAIL_SENDING = auto()
    SLOW_SENDING = auto()
    FAIL_REDUCING = auto()
    SLOW_REDUCING = auto()
    CANCEL = auto()
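
# How each fault manifests in the classes below:
#   FAIL_BEFORE   - raise before the all-reduce runner is even constructed
#   FAIL_SENDING  - raise midway through streaming local tensor parts to the last reducer
#   SLOW_SENDING  - stall for 10 s at the same point (longer than the test's allreduce_timeout)
#   FAIL_REDUCING - reply with an INTERNAL_ERROR code partway through rpc_aggregate_part
#   SLOW_REDUCING - stall for 10 s inside rpc_aggregate_part
#   CANCEL        - report a CANCELLED code; the test also cancels that peer's step() future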


class FaultyAverager(hivemind.DecentralizedAverager):
    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
        self.fault = fault
        super().__init__(*args, **kwargs)

    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            if self.fault == Fault.FAIL_BEFORE:
                raise Exception("Oops, I failed!")

            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                allreduce = FaultyAllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id,
                    tensors=local_tensors,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    modes=modes,
                    fault=self.fault,
                    **kwargs,
                )

                self._running_groups[group_info.group_id].set_result(allreduce)
                # TODO maybe this can be extracted into a method that checks if register_... context is active.

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    iter_results = allreduce.run()
                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                    self._state_updated.set()

                else:
                    async for _ in allreduce:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

            return user_gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")


class FaultyAllReduceRunner(AllReduceRunner):
    def __init__(self, *args, fault: Fault, **kwargs):
        self.fault = fault
        super().__init__(*args, **kwargs)

    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
                yield message

                if i == 2:
                    if self.fault == Fault.FAIL_REDUCING:
                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
                        break
                    else:  # SLOW_REDUCING: stall instead of reporting an error
                        await asyncio.sleep(10)
        elif self.fault == Fault.CANCEL:
            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
        else:
            async for message in super().rpc_aggregate_part(stream, context):
                yield message

    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)

        first_part = await anext(parts_aiter)
        yield averaging_pb2.AveragingData(
            code=averaging_pb2.PART_FOR_AVERAGING,
            group_id=self.group_id,
            tensor_part=first_part,
            weight=self.weight,
        )
        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
            if peer_index == last_reducer_index:
                if self.fault == Fault.FAIL_SENDING:
                    raise Exception("Oops, I failed!")
                else:
                    await asyncio.sleep(10)
        async for part in parts_aiter:
            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
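

# The test below spins up one bootstrap DHT plus five averagers: peer 0 gets fault0,
# peer 1 gets fault1 and runs in client mode, and peers 2-4 are healthy. Peers that
# never contribute (FAIL_BEFORE, CANCEL) are excluded from the reference average;
# the healthy peers must still complete step() and land close to that reference.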
@pytest.mark.forked
@pytest.mark.parametrize(
    "fault0, fault1",
    [
        (Fault.NONE, Fault.FAIL_BEFORE),
        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
        (Fault.NONE, Fault.CANCEL),
    ],
)
def test_fault_tolerance(fault0: Fault, fault1: Fault):
    def _make_tensors():
        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]

    dht = hivemind.DHT(start=True)

    averagers = []
    for i in range(5):
        averager = FaultyAverager(
            _make_tensors(),
            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
            prefix="test",
            request_timeout=0.3,
            min_matchmaking_time=1.0,
            next_chunk_timeout=0.5,
            allreduce_timeout=5,
            part_size_bytes=2**16,
            client_mode=(i == 1),
            start=True,
            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
        )
        averagers.append(averager)

    ref_numerators = [0, 0, 0, 0]
    ref_denominator = 0

    for averager in averagers:
        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
            with averager.get_tensors() as tensors:
                for i, tensor in enumerate(tensors):
                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
                ref_denominator += 1

    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))

    flat_local_tensors = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))

    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
    for i, averager in enumerate(averagers):
        if averager.fault == Fault.CANCEL:
            futures[i].cancel()

    for future in futures[2:]:
        assert future.result()
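
    # Acceptance criteria for the healthy peers (2-4), depending on where the faults hit:
    # either they match the reference average everywhere, or a peer failed mid-round and
    # only part of the tensor was averaged while the rest keeps its previous local values.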
    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
        with averager.get_tensors() as tensors:
            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))

        diff_with_reference = abs(flat_ref - flat_tensors)

        if all(fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
            diff_to_reference = abs(flat_ref - flat_tensors)
            diff_to_local = abs(prev_local_tensors - flat_tensors)
            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
            assert torch.all(torch.minimum(diff_to_reference, diff_to_local) < 1e-5).item()
        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
            assert diff_with_reference.max() < 1e-5
        else:
            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5

    for averager in averagers:
        averager.shutdown()
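

# Minimal sketch for running a single fault scenario without pytest (assumes hivemind
# and its bundled p2p daemon are installed; the pytest marks above are inert when the
# function is called directly).
if __name__ == "__main__":
    test_fault_tolerance(Fault.NONE, Fault.FAIL_BEFORE)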