|
@@ -11,6 +11,7 @@ from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get
|
|
|
|
|
|
GroupKey = str  # type alias: a group key is represented as a plain string
|
|
|
GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$") # <one or more non-dot characters>.0b<zero or more binary digits>, e.g. bert_exp4_averaging.0b01001101
|
|
|
+DEFAULT_NUM_BUCKETS = 256  # NOTE(review): not referenced by any code visible here - confirm where this default is consumed
|
|
|
logger = get_logger(__name__)  # module-level logger, named after this module
|
|
|
|
|
|
|
|
@@ -29,9 +30,12 @@ class GroupKeyManager:
|
|
|
dht: DHT,
|
|
|
prefix: str,
|
|
|
initial_group_bits: str,
|
|
|
- target_group_size: int,
|
|
|
+ target_group_size: Optional[int],
|
|
|
):
|
|
|
assert all(bit in "01" for bit in initial_group_bits)
|
|
|
+ if target_group_size is not None and not is_power_of_two(target_group_size):
|
|
|
+ logger.warning("It is recommended to set target_group_size to a power of 2.")
|
|
|
+
|
|
|
self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
|
|
|
self.target_group_size = target_group_size
|
|
|
self.peer_id = dht.peer_id
|
|
@@ -92,8 +96,11 @@ class GroupKeyManager:
|
|
|
"""this function is triggered every time an averager finds an allreduce group"""
|
|
|
rng = random.Random(group_info.group_id)
|
|
|
index = group_info.peer_ids.index(self.peer_id)
|
|
|
- generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
|
|
|
- nbits = int(np.ceil(np.log2(self.target_group_size)))
|
|
|
+ num_buckets = self.target_group_size
|
|
|
+ if num_buckets is None:
|
|
|
+ num_buckets = next_power_of_two(group_info.group_size)
|
|
|
+ generalized_index = rng.sample(range(num_buckets), group_info.group_size)[index]
|
|
|
+ nbits = int(np.ceil(np.log2(num_buckets)))
|
|
|
new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
|
|
|
self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
|
|
|
logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
|
|
@@ -101,3 +108,13 @@ class GroupKeyManager:
|
|
|
async def update_key_on_not_enough_peers(self):
    """Hook called whenever the averager fails to assemble a group within the timeout.

    This base implementation is a deliberate no-op; it is meant to be implemented in subclasses.
    """
    return None
|
|
|
+
|
|
|
+
|
|
|
def is_power_of_two(n):
    """Tell whether the integer n is an exact power of two (1, 2, 4, 8, ...).

    Zero and negative integers are never powers of two, so they yield False.
    """
    if n == 0:
        return False
    # Clearing the lowest set bit leaves zero only when exactly one bit was set.
    without_lowest_bit = n & (n - 1)
    return without_lowest_bit == 0
|
|
|
+
|
|
|
+
|
|
|
def next_power_of_two(n: int) -> int:
    """Round n up to the nearest power of 2.

    :param n: an integer
    :returns: the smallest 2 ** k (k >= 0) that is greater than or equal to n; exact powers of two
      are returned unchanged, and every n <= 1 yields 1 because 1 == 2 ** 0 is the smallest such power
    """
    if n <= 1:
        # For n <= 0, (n - 1) is negative and int.bit_length() measures its magnitude, which would
        # produce a meaningless result (e.g. 8 for n == -5) instead of the smallest power of two.
        return 1
    # (n - 1).bit_length() is the number of bits needed for n - 1, i.e. ceil(log2(n)) for n >= 2
    return 1 << (n - 1).bit_length()
|