Bläddra i källkod

Make target group size optional (#412)

This pull-request makes target_group_size optional, to avoid bothering user with setting it.
The rationale here is that in most cases we ended up using `target_group_size=some_very_large_number_please_be_large_enough`.
Therefore, it would make sense to make this a default choice to avoid concerning users with the extra low-level kwarg.


Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 år sedan
förälder
incheckning
22665fdcee

+ 1 - 8
hivemind/averaging/averager.py

@@ -112,7 +112,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         *,
         start: bool,
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
         averaging_expiration: Optional[float] = None,
@@ -137,8 +137,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert bandwidth is None or (
             bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
         ), "bandwidth must be a non-negative float32"
-        if not is_power_of_two(target_group_size):
-            logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
@@ -697,11 +695,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 future.set_exception(e)
 
 
-def is_power_of_two(n):
-    """Check whether n is a power of 2"""
-    return (n != 0) and (n & (n - 1) == 0)
-
-
 def _background_thread_fetch_current_state(
     serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
 ):

+ 20 - 3
hivemind/averaging/key_manager.py

@@ -11,6 +11,7 @@ from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
+DEFAULT_NUM_BUCKETS = 256
 logger = get_logger(__name__)
 
 
@@ -29,9 +30,12 @@ class GroupKeyManager:
         dht: DHT,
         prefix: str,
         initial_group_bits: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
     ):
         assert all(bit in "01" for bit in initial_group_bits)
+        if target_group_size is not None and not is_power_of_two(target_group_size):
+            logger.warning("It is recommended to set target_group_size to a power of 2.")
+
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.peer_id = dht.peer_id
@@ -92,8 +96,11 @@ class GroupKeyManager:
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
         index = group_info.peer_ids.index(self.peer_id)
-        generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
-        nbits = int(np.ceil(np.log2(self.target_group_size)))
+        num_buckets = self.target_group_size
+        if num_buckets is None:
+            num_buckets = next_power_of_two(group_info.group_size)
+        generalized_index = rng.sample(range(num_buckets), group_info.group_size)[index]
+        nbits = int(np.ceil(np.log2(num_buckets)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
@@ -101,3 +108,13 @@ class GroupKeyManager:
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
         pass  # to be implemented in subclasses
+
+
+def is_power_of_two(n):
+    """Check whether n is a power of 2"""
+    return (n != 0) and (n & (n - 1) == 0)
+
+
+def next_power_of_two(n):
+    """Round n up to the nearest power of 2"""
+    return 1 if n == 0 else 2 ** (n - 1).bit_length()

+ 7 - 3
hivemind/averaging/matchmaking.py

@@ -44,7 +44,7 @@ class Matchmaking:
         *,
         servicer_type: Type[ServicerBase],
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
         min_group_size: int,
         min_matchmaking_time: float,
         request_timeout: float,
@@ -267,7 +267,11 @@ class Matchmaking:
                 self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
-                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
+                if (
+                    self.target_group_size is not None
+                    and len(self.current_followers) + 1 >= self.target_group_size
+                    and not self.assembled_group.done()
+                ):
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()
 
@@ -353,7 +357,7 @@ class Matchmaking:
             )
         elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
-        elif len(self.current_followers) + 1 >= self.target_group_size:
+        elif self.target_group_size is not None and len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
             return None

+ 0 - 1
tests/test_averaging.py

@@ -428,7 +428,6 @@ def test_averaging_trigger():
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
-            target_group_size=4,
             min_matchmaking_time=0.5,
             request_timeout=0.3,
             prefix="mygroup",