|
@@ -1,4 +1,3 @@
|
|
|
-import asyncio
|
|
|
import re
|
|
|
import random
|
|
|
from typing import Optional, List, Tuple
|
|
@@ -7,11 +6,12 @@ import numpy as np
|
|
|
|
|
|
from hivemind.dht import DHT
|
|
|
from hivemind.client.averaging.allreduce import AllReduceRunner
|
|
|
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
|
|
|
+from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time
|
|
|
|
|
|
GroupKey = str
|
|
|
GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$') # e.g. bert_exp4_averaging.0b01001101
|
|
|
logger = get_logger(__name__)
|
|
|
+SUCCESS_KEY = 'abyrvalg'
|
|
|
|
|
|
|
|
|
def is_valid_group(maybe_group: str) -> bool:
|
|
@@ -30,15 +30,19 @@ class GroupKeyManager:
|
|
|
nbits_expiration: float = 60, nbits_rewrite_grace_period: float = 15):
|
|
|
assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
|
|
|
if initial_group_bits is None:
|
|
|
- search_result = dht.get(f"{prefix}.0b", latest=True)
|
|
|
- initial_group_nbits = self.get_suggested_nbits(search_result) or 0
|
|
|
+ initial_group_nbits = 0 # TODO
|
|
|
initial_group_bits = ''.join(random.choice('01') for _ in range(initial_group_nbits))
|
|
|
self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
|
|
|
self.target_group_size = target_group_size
|
|
|
self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
|
|
|
- self.excessive_size = excessive_size or target_group_size * 3
|
|
|
+ self.excessive_size = excessive_size or round(target_group_size * 1.1)
|
|
|
self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
|
|
|
self.suggested_nbits: Optional[int] = None
|
|
|
+ self.num_active_averagers = 0
|
|
|
+ self.success_upper = 0
|
|
|
+ self.success_lower = 0
|
|
|
+ self.small_group = 0
|
|
|
+ self.crowded = 0
|
|
|
|
|
|
@property
|
|
|
def current_key(self) -> GroupKey:
|
|
@@ -78,33 +82,14 @@ class GroupKeyManager:
|
|
|
return []
|
|
|
averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
|
|
|
if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
|
|
|
- num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
|
|
|
-
|
|
|
- suggested_nbits = self.get_suggested_nbits(result)
|
|
|
- if suggested_nbits is not None and suggested_nbits != len(self.group_bits) and \
|
|
|
- suggested_nbits != self.suggested_nbits:
|
|
|
- self.suggested_nbits = suggested_nbits
|
|
|
- logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
|
|
|
- elif num_active_averagers >= self.excessive_size:
|
|
|
- self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
|
|
|
- logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
|
|
|
- return averagers
|
|
|
-
|
|
|
- async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
|
|
|
- """ notify other peers that they can run averaging at this depth """
|
|
|
- return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
|
|
|
- expiration_time=expiration_time, return_future=True)
|
|
|
+ self.num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
|
|
|
|
|
|
- @classmethod
|
|
|
- def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
|
|
|
- if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
|
|
|
- and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
|
|
|
- return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
|
|
|
- else:
|
|
|
- return None
|
|
|
+ return averagers
|
|
|
|
|
|
async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
|
|
|
""" this function is triggered every time an averager finds an allreduce group """
|
|
|
+
|
|
|
+ # IMPORTANT LOGIC OF MOSHPIT SGD
|
|
|
rng = random.Random(allreduce_group.group_key_seed)
|
|
|
index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
|
|
|
generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
|
|
@@ -112,35 +97,44 @@ class GroupKeyManager:
|
|
|
new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
|
|
|
self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else ''
|
|
|
logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
|
|
|
+ # /IMPORTANT LOGIC OF MOSHPIT SGD
|
|
|
+
|
|
|
+ c = await self.dht.get(key=SUCCESS_KEY, latest=True, return_future=True)
|
|
|
+ if c is not None:
|
|
|
+ d = c.value
|
|
|
+ if len(self.group_bits) - 1 in d:
|
|
|
+ self.success_lower += 1
|
|
|
+ else:
|
|
|
+ self.success_lower = 0
|
|
|
+ if len(self.group_bits) + 1 in d:
|
|
|
+ self.success_upper += 1
|
|
|
+ else:
|
|
|
+ self.success_upper = 0
|
|
|
+
|
|
|
+ else:
|
|
|
+ self.success_upper = 0
|
|
|
+ self.success_lower = 0
|
|
|
+
|
|
|
+ if is_leader and self.target_group_size//2 <= self.num_active_averagers <= self.target_group_size:
|
|
|
+ await self.dht.store(key=SUCCESS_KEY, subkey=len(self.group_bits), value=True,
|
|
|
+ expiration_time=get_dht_time() + 1, return_future=True, )
|
|
|
+ if self.num_active_averagers > self.target_group_size:
|
|
|
+ self.crowded += 1
|
|
|
+ else:
|
|
|
+ self.crowded = 0
|
|
|
+
|
|
|
+ if self.num_active_averagers <= self.target_group_size//2:
|
|
|
+ self.small_group += 1
|
|
|
+ if self.small_group > 1 and len(self.group_bits):
|
|
|
+ self.group_bits = self.group_bits[1:]
|
|
|
+ return
|
|
|
+ else:
|
|
|
+ self.small_group = 0
|
|
|
+
|
|
|
+ if self.success_upper > 2 or self.crowded > 1:
|
|
|
+ self.group_bits = random.choice('01') + self.group_bits
|
|
|
+ return
|
|
|
|
|
|
- if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
|
|
|
- asyncio.create_task(self.notify_stragglers())
|
|
|
- if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
|
|
|
- num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
|
|
|
- self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
|
|
|
- self.group_bits = self.group_bits[-self.suggested_nbits:]
|
|
|
- self.suggested_nbits = None
|
|
|
|
|
|
async def update_key_on_not_enough_peers(self):
|
|
|
""" this function is triggered whenever averager fails to assemble group within timeout """
|
|
|
- new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
|
|
|
- prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ''
|
|
|
- if self.group_bits != prev_nbits:
|
|
|
- logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
|
|
|
- self.suggested_nbits = None
|
|
|
-
|
|
|
- async def notify_stragglers(self):
|
|
|
- """ Find averagers that have fewer nbits and redirect them to your current nbits """
|
|
|
- for nbits in reversed(range(1, len(self.group_bits) - 1)):
|
|
|
- preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
|
|
|
- preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
|
|
|
-
|
|
|
- if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
|
|
|
- await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
|
|
|
- break
|
|
|
-
|
|
|
- root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
|
|
|
- if isinstance(root_data, dict) and root_data.get(
|
|
|
- self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
|
|
|
- return
|
|
|
- await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
|