test_allreduce_fault_tolerance.py

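"""Fault-tolerance tests for hivemind's decentralized all-reduce.

FaultyAverager and FaultyAllReduceRunner below subclass the real averaging classes and inject
failures at specific points of the protocol. test_fault_tolerance then checks that the remaining
healthy peers still complete averaging and end up with sensible tensors.
"""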
from __future__ import annotations

import asyncio
from enum import Enum, auto
from typing import AsyncIterator

import pytest
import torch

import hivemind
from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
from hivemind.averaging.averager import *
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import MatchmakingException
from hivemind.proto import averaging_pb2
from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class Fault(Enum):
    NONE = auto()
    FAIL_BEFORE = auto()
    FAIL_SENDING = auto()
    SLOW_SENDING = auto()
    FAIL_REDUCING = auto()
    SLOW_REDUCING = auto()
    CANCEL = auto()
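

# Where each fault is injected:
#   FAIL_BEFORE                   - FaultyAverager._run_allreduce raises before all-reduce starts
#   FAIL_SENDING / SLOW_SENDING   - FaultyAllReduceRunner._generate_input_for_peer raises or stalls mid-stream
#   FAIL_REDUCING / SLOW_REDUCING - FaultyAllReduceRunner.rpc_aggregate_part errors out or stalls mid-stream
#   CANCEL                        - the reducer replies with a CANCELLED code; the test also cancels that peer's step() future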


class FaultyAverager(hivemind.DecentralizedAverager):
    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
        self.fault = fault
        super().__init__(*args, **kwargs)

    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            if self.fault == Fault.FAIL_BEFORE:
                raise Exception("Oops, I failed!")

            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                allreduce = FaultyAllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id,
                    tensors=local_tensors,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    fault=self.fault,
                    **kwargs,
                )
                self._running_groups[group_info.group_id].set_result(allreduce)
                # ^--- maybe this can be extracted into a method that checks if register_... context is active.

                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                    iter_results = allreduce.run()
                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
                        # all-reduce is performed asynchronously while iterating
                        tensor.add_(update, alpha=self._averaging_alpha)
                    self._state_updated.set()
                else:
                    async for _ in allreduce:  # trigger all-reduce by iterating
                        raise ValueError("aux peers should not receive averaged tensors")

                return allreduce.gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
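

# FaultyAllReduceRunner overrides the two streaming endpoints of AllReduceRunner: rpc_aggregate_part
# (what this peer sends back as a reducer) and _generate_input_for_peer (what it streams out as a sender),
# so faults can be triggered on either side of the exchange.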


class FaultyAllReduceRunner(AllReduceRunner):
    def __init__(self, *args, fault: Fault, **kwargs):
        self.fault = fault
        super().__init__(*args, **kwargs)

    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
                yield message

                if i == 2:
                    # a failing reducer reports an explicit error; a slow one simply stalls
                    if self.fault == Fault.FAIL_REDUCING:
                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
                        break
                    else:
                        await asyncio.sleep(10)

        elif self.fault == Fault.CANCEL:
            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
        else:
            async for message in super().rpc_aggregate_part(stream, context):
                yield message

    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
        first_part = await anext(parts_aiter)
        yield averaging_pb2.AveragingData(
            code=averaging_pb2.PART_FOR_AVERAGING,
            group_id=self.group_id,
            tensor_part=first_part,
            weight=self.weight,
        )
        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
            # fail or stall only when sending to the last active reducer, so parts sent to earlier reducers go through
            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
            if peer_index == last_reducer_index:
                if self.fault == Fault.FAIL_SENDING:
                    raise Exception("Oops, I failed!")
                else:
                    await asyncio.sleep(10)
        async for part in parts_aiter:
            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
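

# The test spawns five averagers: peer 0 gets fault0, peer 1 gets fault1 and runs in client mode,
# and peers 2-4 stay healthy. The reference average is computed over every peer that is expected to
# contribute tensors; depending on the fault combination, the healthy peers must either match this
# reference exactly, match it on the unaffected parts, or fall back to their local tensors.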


@pytest.mark.forked
@pytest.mark.parametrize(
    "fault0, fault1",
    [
        (Fault.NONE, Fault.FAIL_BEFORE),
        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
        (Fault.NONE, Fault.CANCEL),
    ],
)
def test_fault_tolerance(fault0: Fault, fault1: Fault):
    def _make_tensors():
        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]

    dht = hivemind.DHT(start=True)

    averagers = []
    for i in range(5):
        averager = FaultyAverager(
            _make_tensors(),
            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
            prefix="test",
            request_timeout=0.3,
            min_matchmaking_time=1.0,
            next_chunk_timeout=0.5,
            allreduce_timeout=5,
            part_size_bytes=2**16,
            client_mode=(i == 1),
            start=True,
            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
        )
        averagers.append(averager)

    ref_numerators = [0, 0, 0, 0]
    ref_denominator = 0

    for averager in averagers:
        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
            with averager.get_tensors() as tensors:
                for i, tensor in enumerate(tensors):
                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
                ref_denominator += 1

    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))

    flat_local_tensors = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))

    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
    for i, averager in enumerate(averagers):
        if averager.fault == Fault.CANCEL:
            futures[i].cancel()

    for future in futures[2:]:
        assert future.result()

    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
        with averager.get_tensors() as tensors:
            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))

        diff_with_reference = abs(flat_ref - flat_tensors)

        if all(fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
            diff_to_local = abs(prev_local_tensors - flat_tensors)
            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
            assert torch.all(torch.minimum(diff_with_reference, diff_to_local) < 1e-5).item()
        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
            assert diff_with_reference.max() < 1e-5
        else:
            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5

    for averager in averagers:
        averager.shutdown()
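

# To run these cases locally, something like the following should work (the @pytest.mark.forked
# decorator requires the pytest-forked plugin; the path below assumes the usual tests/ layout):
#
#   pytest tests/test_allreduce_fault_tolerance.py -k test_fault_tolerance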