# key_manager.py
  1. import asyncio
  2. import re
  3. import random
  4. from typing import Optional, List, Tuple
  5. import numpy as np
  6. from hivemind.dht import DHT
  7. from hivemind.client.averaging.allreduce import AllReduceRunner
  8. from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
  9. GroupKey = str
  10. GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$') # e.g. bert_exp4_averaging.0b01001101
  11. logger = get_logger(__name__)
  12. def is_valid_group(maybe_group: str) -> bool:
  13. """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
  14. return bool(GROUP_PATTERN.fullmatch(maybe_group))
  15. class GroupKeyManager:
  16. """
  17. Utility class that declares and fetches averaging-related keys using a DHT
  18. """
  19. RESERVED_KEY_FOR_NBITS = '::NBITS'
  20. def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
  21. target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
  22. nbits_expiration: float = 60):
  23. assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
  24. if initial_group_bits is None:
  25. search_result = dht.get(f"{prefix}.0b", latest=True)
  26. initial_group_bits = self.get_suggested_nbits(search_result) or ''
  27. self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
  28. self.target_group_size = target_group_size
  29. self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
  30. self.excessive_size = excessive_size or target_group_size * 3
  31. self.nbits_expiration = nbits_expiration
  32. self.suggested_nbits: Optional[int] = None
  33. @property
  34. def current_key(self) -> GroupKey:
  35. return f"{self.prefix}.0b{self.group_bits}"
  36. async def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float,
  37. looking_for_group: bool = True) -> bool:
  38. """
  39. Add (or remove) the averager to a given allreduce bucket
  40. :param group_key: allreduce group key, e.g. my_averager.0b011011101
  41. :param endpoint: averager public endpoint for incoming requests
  42. :param expiration_time: intent to run allreduce before this timestamp
  43. :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
  44. If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
  45. :return: True if declared, False if declaration was rejected by DHT peers
  46. :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
  47. :note: setting is_active=False does *not* guarantee that others will immediately stop to query you.
  48. """
  49. expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float('inf')))
  50. return await self.dht.store(key=group_key, subkey=endpoint, value=looking_for_group,
  51. expiration_time=expiration_time, return_future=True)
  52. async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
  53. """
  54. Find and return averagers that were declared with a given all-reduce key
  55. :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
  56. :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
  57. if False, return all averagers under a given group_key regardless of value
  58. :return: endpoints and expirations of every matching averager
  59. """
  60. assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
  61. result = await self.dht.get(group_key, latest=True, return_future=True)
  62. if result is None or not isinstance(result.value, dict):
  63. logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
  64. return []
  65. averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
  66. if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
  67. num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
  68. suggested_nbits = self.get_suggested_nbits(result)
  69. if suggested_nbits is not None and suggested_nbits != self.suggested_nbits:
  70. self.suggested_nbits = suggested_nbits
  71. logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
  72. elif num_active_averagers >= self.excessive_size:
  73. self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
  74. logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
  75. return averagers
  76. async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
  77. """ notify other peers that they can run averaging at this depth """
  78. return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
  79. expiration_time=expiration_time, return_future=True)
  80. @classmethod
  81. def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
  82. if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
  83. and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
  84. return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
  85. else:
  86. return None
  87. async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
  88. """ this function is triggered every time an averager finds an allreduce group """
  89. rng = random.Random(allreduce_group.group_key_seed)
  90. index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
  91. generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
  92. nbits = int(np.ceil(np.log2(self.target_group_size)))
  93. new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
  94. self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):]
  95. logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
  96. if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
  97. asyncio.create_task(self.notify_stragglers_on_success())
  98. if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
  99. num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
  100. self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
  101. self.group_bits = self.group_bits[-self.suggested_nbits:]
  102. self.suggested_nbits = None
  103. async def update_key_on_not_enough_peers(self):
  104. """ this function is triggered whenever averager fails to assemble group within timeout """
  105. new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
  106. prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:]
  107. if self.group_bits != prev_nbits:
  108. logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
  109. self.suggested_nbits = None
  110. async def notify_stragglers_on_success(self):
  111. """ Find averagers that have fewer nbits and redirect them to your current nbits """
  112. for nbits in reversed(range(1, len(self.group_bits) - 1)):
  113. preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
  114. preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
  115. if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
  116. await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
  117. break
  118. root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
  119. if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
  120. await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)