浏览代码

Update key_manager.py

Michael Diskin 4 年之前
父节点
当前提交
dc0ea7580e
共有 1 个文件被更改，包括 50 次插入和 56 次删除
  1. 50 56
      hivemind/client/averaging/key_manager.py

+ 50 - 56
hivemind/client/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
 import re
 import re
 import random
 import random
 from typing import Optional, List, Tuple
 from typing import Optional, List, Tuple
@@ -7,11 +6,12 @@ import numpy as np
 
 
 from hivemind.dht import DHT
 from hivemind.dht import DHT
 from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.allreduce import AllReduceRunner
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time
 
 
 GroupKey = str
 GroupKey = str
 GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$')  # e.g. bert_exp4_averaging.0b01001101
 GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$')  # e.g. bert_exp4_averaging.0b01001101
 logger = get_logger(__name__)
 logger = get_logger(__name__)
+SUCCESS_KEY = 'abyrvalg'
 
 
 
 
 def is_valid_group(maybe_group: str) -> bool:
 def is_valid_group(maybe_group: str) -> bool:
@@ -30,15 +30,19 @@ class GroupKeyManager:
                  nbits_expiration: float = 60, nbits_rewrite_grace_period: float = 15):
                  nbits_expiration: float = 60, nbits_rewrite_grace_period: float = 15):
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
         if initial_group_bits is None:
         if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
+            initial_group_nbits = 0  # TODO
             initial_group_bits = ''.join(random.choice('01') for _ in range(initial_group_nbits))
             initial_group_bits = ''.join(random.choice('01') for _ in range(initial_group_nbits))
         self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
         self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
+        self.excessive_size = excessive_size or round(target_group_size * 1.1)
         self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
         self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
         self.suggested_nbits: Optional[int] = None
         self.suggested_nbits: Optional[int] = None
+        self.num_active_averagers = 0
+        self.success_upper = 0
+        self.success_lower = 0
+        self.small_group = 0
+        self.crowded = 0
 
 
     @property
     @property
     def current_key(self) -> GroupKey:
     def current_key(self) -> GroupKey:
@@ -78,33 +82,14 @@ class GroupKeyManager:
             return []
             return []
         averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
         averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
                      if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
                      if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
-        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if suggested_nbits is not None and suggested_nbits != len(self.group_bits) and \
-                suggested_nbits != self.suggested_nbits:
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
-        return averagers
-
-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """ notify other peers that they can run averaging at this depth """
-        return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
-                                    expiration_time=expiration_time, return_future=True)
+        self.num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
 
 
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
-                and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
+        return averagers
 
 
     async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
     async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
         """ this function is triggered every time an averager finds an allreduce group """
         """ this function is triggered every time an averager finds an allreduce group """
+
+        # IMPORTANT LOGIC OF MOSHPIT SGD
         rng = random.Random(allreduce_group.group_key_seed)
         rng = random.Random(allreduce_group.group_key_seed)
         index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
         index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
         generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
         generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
@@ -112,35 +97,44 @@ class GroupKeyManager:
         new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
         new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else ''
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else ''
         logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
         logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+        # /IMPORTANT LOGIC OF MOSHPIT SGD
+
+        c = await self.dht.get(key=SUCCESS_KEY, latest=True, return_future=True)
+        if c is not None:
+            d = c.value
+            if len(self.group_bits) - 1 in d:
+                self.success_lower += 1
+            else:
+                self.success_lower = 0
+            if len(self.group_bits) + 1 in d:
+                self.success_upper += 1
+            else:
+                self.success_upper = 0
+
+        else:
+            self.success_upper = 0
+            self.success_lower = 0
+
+        if is_leader and self.target_group_size//2 <= self.num_active_averagers <= self.target_group_size:
+            await self.dht.store(key=SUCCESS_KEY, subkey=len(self.group_bits), value=True,
+                                 expiration_time=get_dht_time() + 1, return_future=True, )
+        if self.num_active_averagers > self.target_group_size:
+            self.crowded += 1
+        else:
+            self.crowded = 0
+
+        if self.num_active_averagers <= self.target_group_size//2:
+            self.small_group += 1
+            if self.small_group > 1 and len(self.group_bits):
+                self.group_bits = self.group_bits[1:]
+                return
+        else:
+            self.small_group = 0
+
+        if self.success_upper > 2 or self.crowded > 1:
+            self.group_bits = random.choice('01') + self.group_bits
+            return
 
 
-        if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits:]
-        self.suggested_nbits = None
 
 
     async def update_key_on_not_enough_peers(self):
     async def update_key_on_not_enough_peers(self):
         """ this function is triggered whenever averager fails to assemble group within timeout """
         """ this function is triggered whenever averager fails to assemble group within timeout """
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ''
-        if self.group_bits != prev_nbits:
-            logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """ Find averagers that have fewer nbits and redirect them to your current nbits """
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if isinstance(root_data, dict) and root_data.get(
-                self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)