123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- import asyncio
- import re
- import random
- from typing import Optional, List, Tuple
- import numpy as np
- from hivemind.dht import DHT
- from hivemind.client.averaging.allreduce import AllReduceRunner
- from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
# A group key has the form "{prefix}.0b{bits}": a dot-free prefix, the literal ".0b",
# then a (possibly empty) binary string identifying the all-reduce bucket.
GroupKey = str
GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$')  # e.g. bert_exp4_averaging.0b01001101
logger = get_logger(__name__)
def is_valid_group(maybe_group: str) -> bool:
    """ Check that a string is a well-formed group key: a dot-free prefix, ".0b", then zero or more 0/1 bits """
    return GROUP_PATTERN.fullmatch(maybe_group) is not None
class GroupKeyManager:
    """
    Utility class that declares and fetches averaging-related keys using a DHT.

    Group keys look like "{prefix}.0b{group_bits}" (e.g. bert_exp4_averaging.0b01001101); averagers that
    declare themselves under the same key are candidates for the same all-reduce group. The manager also
    propagates a suggested key depth (number of group bits) through a reserved DHT subkey so that peers
    converge towards buckets of a sensible size.
    """
    RESERVED_KEY_FOR_NBITS = '::NBITS'  # special DHT subkey used to broadcast the suggested number of group bits

    def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
                 target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
                 nbits_expiration: float = 60):
        """
        :param dht: DHT instance used to store and fetch averaging keys
        :param endpoint: this averager's public endpoint
        :param prefix: common key prefix shared by all averagers of this experiment
        :param initial_group_bits: initial binary string of group bits; if None, infer the depth from the DHT
        :param target_group_size: desired number of peers per all-reduce group
        :param insufficient_size: a bucket with fewer active peers is considered underpopulated
            (default: half of target_group_size, at least 1)
        :param excessive_size: a bucket with this many or more active peers is considered overcrowded
            (default: 3x target_group_size)
        :param nbits_expiration: how long (in DHT time units) an nbits suggestion stays valid
        """
        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
        if initial_group_bits is None:
            # infer the suggested depth from the root key and draw that many random bits.
            # NOTE: fixes a bug where the integer nbits itself was assigned to group_bits,
            # producing malformed keys such as "prefix.0b3"
            search_result = dht.get(f"{prefix}.0b", latest=True)
            suggested_nbits = self.get_suggested_nbits(search_result) or 0
            initial_group_bits = ''.join(random.choice('01') for _ in range(suggested_nbits))
        self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
        self.target_group_size = target_group_size
        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
        self.excessive_size = excessive_size or target_group_size * 3
        self.nbits_expiration = nbits_expiration
        self.suggested_nbits: Optional[int] = None  # set whenever this peer (or another one) suggests a new depth

    @property
    def current_key(self) -> GroupKey:
        """ The group key this averager currently matches under, e.g. "bert_exp4_averaging.0b01101" """
        return f"{self.prefix}.0b{self.group_bits}"

    async def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float,
                               looking_for_group: bool = True) -> bool:
        """
        Add (or remove) the averager to a given allreduce bucket

        :param group_key: allreduce group key, e.g. my_averager.0b011011101
        :param endpoint: averager public endpoint for incoming requests
        :param expiration_time: intent to run allreduce before this timestamp
        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
          If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
        :return: True if declared, False if declaration was rejected by DHT peers
        :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
        :note: setting is_active=False does *not* guarantee that others will immediately stop to query you.
        """
        # when leaving, nudge expiration one ulp later so this store supersedes the earlier "looking" entry
        expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float('inf')))
        return await self.dht.store(key=group_key, subkey=endpoint, value=looking_for_group,
                                    expiration_time=expiration_time, return_future=True)

    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
        """
        Find and return averagers that were declared with a given all-reduce key

        :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
            if False, return all averagers under a given group_key regardless of value
        :return: endpoints and expirations of every matching averager
        """
        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
        result = await self.dht.get(group_key, latest=True, return_future=True)
        if result is None or not isinstance(result.value, dict):
            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
            return []
        averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
                     if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])

        suggested_nbits = self.get_suggested_nbits(result)
        if suggested_nbits is not None and suggested_nbits != self.suggested_nbits:
            # another peer broadcast a depth suggestion: adopt it (takes precedence over the overcrowding check)
            self.suggested_nbits = suggested_nbits
            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
        elif num_active_averagers >= self.excessive_size:
            # bucket is overcrowded: suggest going at least one bit deeper
            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
        return averagers

    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
        """ notify other peers that they can run averaging at this depth """
        return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
                                    expiration_time=expiration_time, return_future=True)

    @classmethod
    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
        """ Extract the integer nbits suggestion from a DHT lookup result, or None if absent/malformed """
        if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
                and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
        else:
            return None

    async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
        """ this function is triggered every time an averager finds an allreduce group """
        # all group members share group_key_seed, so every peer derives the same permutation
        rng = random.Random(allreduce_group.group_key_seed)
        index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
        generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
        nbits = int(np.ceil(np.log2(self.target_group_size)))
        new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
        if self.group_bits:
            # keep the key length unchanged: append fresh bits, retain the suffix.
            # NOTE: the guard fixes the empty-key case, where ''[-0:] would return all of new_bits and grow the key
            self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):]
        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")

        if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
            asyncio.create_task(self.notify_stragglers_on_success())
        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
            # resize the key to the suggested depth, padding with random bits if it must grow
            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
            self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
            self.group_bits = self.group_bits[-self.suggested_nbits:]
        self.suggested_nbits = None

    async def update_key_on_not_enough_peers(self):
        """ this function is triggered whenever averager fails to assemble group within timeout """
        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
        prev_bits = self.group_bits
        # NOTE: plain slicing would be wrong for new_nbits == 0, since bits[-0:] returns the whole string
        self.group_bits = self.group_bits[-new_nbits:] if new_nbits > 0 else ''
        if self.group_bits != prev_bits:
            logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
        self.suggested_nbits = None

    async def notify_stragglers_on_success(self):
        """ Find averagers that have fewer nbits and redirect them to your current nbits """
        for nbits in reversed(range(1, len(self.group_bits) - 1)):
            # NOTE: the "if nbits" guard is dead code here (range starts at 1); kept for safety if the range changes
            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)

            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
                break

        root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
        if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
            await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
|