瀏覽代碼

Disable elasticity for averaging, add error handling (#362)

* disable old elasticity code
* change DHT code in such a way that can recover from file/memory errors
* reduce the size of shmem buffer based on our performance benchmarks

Currently, dynamic averager nbits is broken, and we deliberately parameterize training in such a way that it will not be triggered. This pull request removes the broken code.

In the future, we can revive elasticity based on @yhn112's earlier attempt at fixing it: dc0ea75

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 年之前
父節點
當前提交
0e90183951

+ 8 - 3
hivemind/averaging/averager.py

@@ -96,7 +96,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         prefix: str,
         target_group_size: int,
         min_group_size: int = 2,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
         request_timeout: float = 3,
         averaging_alpha: float = 1.0,
@@ -117,7 +117,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
+        assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
@@ -241,7 +241,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 self._ready.set_result(None)
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._matchmaking.request_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task

+ 12 - 93
hivemind/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
 import random
 import re
 from typing import List, Optional, Tuple
@@ -25,31 +24,17 @@ class GroupKeyManager:
     Utility class that declares and fetches averaging-related keys using a DHT
     """
 
-    RESERVED_KEY_FOR_NBITS = "::NBITS"
-
     def __init__(
         self,
         dht: DHT,
         prefix: str,
-        initial_group_bits: Optional[str],
+        initial_group_bits: str,
         target_group_size: int,
-        insufficient_size: Optional[int] = None,
-        excessive_size: Optional[int] = None,
-        nbits_expiration: float = 60,
-        nbits_rewrite_grace_period: float = 15,
     ):
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
+        assert all(bit in "01" for bit in initial_group_bits)
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
-        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
-        self.suggested_nbits: Optional[int] = None
+        self.peer_id = dht.peer_id
 
     @property
     def current_key(self) -> GroupKey:
@@ -93,51 +78,16 @@ class GroupKeyManager:
         if result is None or not isinstance(result.value, dict):
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
-        averagers = [
-            (PeerID(key), looking_for_group.expiration_time)
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
-        ]
-        num_active_averagers = sum(
-            1
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
-        )
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if (
-            suggested_nbits is not None
-            and suggested_nbits != len(self.group_bits)
-            and suggested_nbits != self.suggested_nbits
-        ):
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        averagers = []
+        for key, looking_for_group in result.value.items():
+            try:
+                if only_active and not looking_for_group.value:
+                    continue
+                averagers.append((PeerID(key), looking_for_group.expiration_time))
+            except Exception as e:
+                logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
         return averagers
 
-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """notify other peers that they can run averaging at this depth"""
-        return await self.dht.store(
-            key=group_key,
-            subkey=self.RESERVED_KEY_FOR_NBITS,
-            value=nbits,
-            expiration_time=expiration_time,
-            return_future=True,
-        )
-
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if (
-            isinstance(search_result, ValueWithExpiration)
-            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
-            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
-        ):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
-
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
@@ -148,37 +98,6 @@ class GroupKeyManager:
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
-        if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits :]
-        self.suggested_nbits = None
-
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
-        if self.group_bits != prev_nbits:
-            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """Find averagers that have fewer nbits and redirect them to your current nbits"""
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if (
-            isinstance(root_data, dict)
-            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
-            > get_dht_time() + self.nbits_grace_period
-        ):
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        pass  # to be implemented in subclasses

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -45,7 +45,7 @@ class Matchmaking:
         min_group_size: int,
         request_timeout: float,
         client_mode: bool,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."

+ 6 - 1
hivemind/dht/__init__.py

@@ -128,7 +128,12 @@ class DHT(mp.Process):
                 self._ready.set_result(None)
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._node.protocol.wait_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -53,7 +53,7 @@ class SharedBytes:
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         with cls._lock:
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
                 cls._pid = os.getpid()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._index = 0

+ 5 - 2
tests/test_averaging.py

@@ -10,6 +10,7 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 
@@ -363,9 +364,11 @@ def test_too_few_peers():
         )
         for i, dht in enumerate(dht_instances)
     ]
-    step_futures = [averager.step(wait=False) for averager in averagers]
+    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
+
     for future in step_futures:
-        assert len(future.result()) == 2
+        with pytest.raises(AllreduceException):
+            future.result()
 
     for process in averagers + dht_instances:
         process.shutdown()