|
@@ -1,4 +1,3 @@
|
|
|
-import asyncio
|
|
|
import random
|
|
|
import re
|
|
|
from typing import List, Optional, Tuple
|
|
@@ -25,31 +24,17 @@ class GroupKeyManager:
|
|
|
Utility class that declares and fetches averaging-related keys using a DHT
|
|
|
"""
|
|
|
|
|
|
- RESERVED_KEY_FOR_NBITS = "::NBITS"
|
|
|
-
|
|
|
def __init__(
|
|
|
self,
|
|
|
dht: DHT,
|
|
|
prefix: str,
|
|
|
- initial_group_bits: Optional[str],
|
|
|
+ initial_group_bits: str,
|
|
|
target_group_size: int,
|
|
|
- insufficient_size: Optional[int] = None,
|
|
|
- excessive_size: Optional[int] = None,
|
|
|
- nbits_expiration: float = 60,
|
|
|
- nbits_rewrite_grace_period: float = 15,
|
|
|
):
|
|
|
- assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
|
|
|
- if initial_group_bits is None:
|
|
|
- search_result = dht.get(f"{prefix}.0b", latest=True)
|
|
|
- initial_group_nbits = self.get_suggested_nbits(search_result) or 0
|
|
|
- initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
|
|
|
+ assert all(bit in "01" for bit in initial_group_bits)
|
|
|
self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
|
|
|
- self.peer_id = dht.peer_id
|
|
|
self.target_group_size = target_group_size
|
|
|
- self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
|
|
|
- self.excessive_size = excessive_size or target_group_size * 3
|
|
|
- self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
|
|
|
- self.suggested_nbits: Optional[int] = None
|
|
|
+ self.peer_id = dht.peer_id
|
|
|
|
|
|
@property
|
|
|
def current_key(self) -> GroupKey:
|
|
@@ -93,51 +78,16 @@ class GroupKeyManager:
|
|
|
if result is None or not isinstance(result.value, dict):
|
|
|
logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
|
|
|
return []
|
|
|
- averagers = [
|
|
|
- (PeerID(key), looking_for_group.expiration_time)
|
|
|
- for key, looking_for_group in result.value.items()
|
|
|
- if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
|
|
|
- ]
|
|
|
- num_active_averagers = sum(
|
|
|
- 1
|
|
|
- for key, looking_for_group in result.value.items()
|
|
|
- if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
|
|
|
- )
|
|
|
-
|
|
|
- suggested_nbits = self.get_suggested_nbits(result)
|
|
|
- if (
|
|
|
- suggested_nbits is not None
|
|
|
- and suggested_nbits != len(self.group_bits)
|
|
|
- and suggested_nbits != self.suggested_nbits
|
|
|
- ):
|
|
|
- self.suggested_nbits = suggested_nbits
|
|
|
- logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
|
|
|
- elif num_active_averagers >= self.excessive_size:
|
|
|
- self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
|
|
|
- logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
|
|
|
+ averagers = []
|
|
|
+ for key, looking_for_group in result.value.items():
|
|
|
+ try:
|
|
|
+ if only_active and not looking_for_group.value:
|
|
|
+ continue
|
|
|
+ averagers.append((PeerID(key), looking_for_group.expiration_time))
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
|
|
|
return averagers
|
|
|
|
|
|
- async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
|
|
|
- """notify other peers that they can run averaging at this depth"""
|
|
|
- return await self.dht.store(
|
|
|
- key=group_key,
|
|
|
- subkey=self.RESERVED_KEY_FOR_NBITS,
|
|
|
- value=nbits,
|
|
|
- expiration_time=expiration_time,
|
|
|
- return_future=True,
|
|
|
- )
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
|
|
|
- if (
|
|
|
- isinstance(search_result, ValueWithExpiration)
|
|
|
- and cls.RESERVED_KEY_FOR_NBITS in search_result.value
|
|
|
- and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
|
|
|
- ):
|
|
|
- return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
|
|
|
- else:
|
|
|
- return None
|
|
|
-
|
|
|
async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
|
|
|
"""this function is triggered every time an averager finds an allreduce group"""
|
|
|
rng = random.Random(group_info.group_id)
|
|
@@ -148,37 +98,6 @@ class GroupKeyManager:
|
|
|
self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
|
|
|
logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
|
|
|
|
|
|
- if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
|
|
|
- asyncio.create_task(self.notify_stragglers())
|
|
|
- if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
|
|
|
- num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
|
|
|
- self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
|
|
|
- self.group_bits = self.group_bits[-self.suggested_nbits :]
|
|
|
- self.suggested_nbits = None
|
|
|
-
|
|
|
async def update_key_on_not_enough_peers(self):
|
|
|
"""this function is triggered whenever averager fails to assemble group within timeout"""
|
|
|
- new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
|
|
|
- prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
|
|
|
- if self.group_bits != prev_nbits:
|
|
|
- logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
|
|
|
- self.suggested_nbits = None
|
|
|
-
|
|
|
- async def notify_stragglers(self):
|
|
|
- """Find averagers that have fewer nbits and redirect them to your current nbits"""
|
|
|
- for nbits in reversed(range(1, len(self.group_bits) - 1)):
|
|
|
- preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
|
|
|
- preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
|
|
|
-
|
|
|
- if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
|
|
|
- await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
|
|
|
- break
|
|
|
-
|
|
|
- root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
|
|
|
- if (
|
|
|
- isinstance(root_data, dict)
|
|
|
- and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
|
|
|
- > get_dht_time() + self.nbits_grace_period
|
|
|
- ):
|
|
|
- return
|
|
|
- await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
|
|
|
+ pass # to be implemented in subclasses
|