
Negative cache for DHT (#114)

* docstring: enable negative caching

* implement negative caching

* add test for negative caching

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 4 years ago
Commit 8e068a8341
4 changed files with 72 additions and 7 deletions
  1. hivemind/__init__.py (+1 -1)
  2. hivemind/dht/__init__.py (+29 -3)
  3. hivemind/server/__init__.py (+9 -3)
  4. tests/test_dht_experts.py (+33 -0)

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.8'
+__version__ = '0.8.9'

+ 29 - 3
hivemind/dht/__init__.py

@@ -71,6 +71,22 @@ class DHT(mp.Process):
         (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
     :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
+    :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
+      result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
+
+      1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
+         a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
+         result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
+      2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
+         search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
+         experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
+         visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
+      3. Faster beam search in very sparse grids: there is one edge case where negative cache improves beam search
+         performance: if an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
+         indices {i} such that _no_ experts start with "{prefix}.{i}.*"). In that case, the default beam search will
+         be very slow due to the way it forms the initial beam; with negative cache enabled, it runs normally.
+         However, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
+
     :param kwargs: any other params will be forwarded to DHTNode upon creation
 
     Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
@@ -102,11 +118,11 @@ class DHT(mp.Process):
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 receiver_threads: int = 1, expiration: float = 300, **kwargs):
+                 receiver_threads: int = 1, negative_caching: bool = True, expiration: float = 300, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
-        self.expiration = expiration
+        self.expiration, self.negative_caching = expiration, negative_caching
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -262,7 +278,13 @@ class DHT(mp.Process):
                     successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
                                   if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
                                   and len(match.value) == 2}
-                    beam.append((scores[pending_best_index], pending_best_prefix, successors))
+                    if successors:
+                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
+                elif maybe_prefix_data is None and self.negative_caching:
+                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
+                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
+                                                   expiration_time=get_dht_time() + self.expiration))
+
             except asyncio.CancelledError:
                 for _, pending_task in pending_tasks:
                     pending_task.cancel()
@@ -304,6 +326,10 @@ class DHT(mp.Process):
                                       and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
             else:
                 successors[prefix] = {}
+                if found is None and self.negative_caching:
+                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
+                                                   expiration_time=get_dht_time() + self.expiration))
         if future:
             future.set_result(successors)
         return successors
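
A minimal usage sketch of the new flag, built only from the constructor signature in the diff above and the DHT methods used in the test below (the final shutdown() call is assumed to work as in other hivemind examples). Whenever beam search misses a prefix, the peer writes a short-lived "no prefix" placeholder (subkey=-1, value=None) back into the DHT, so subsequent searches skip it until the entry expires:

    import hivemind

    # start a DHT peer in a background process; negative_caching defaults to True,
    # so passing it explicitly only matters when opting out
    dht = hivemind.DHT(listen_on="0.0.0.0:*", start=True,
                       negative_caching=True, expiration=300)

    # declare a couple of experts so that beam search has something to find
    dht.declare_experts(['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234')

    # beam search over prefixes ffn.0. ... ffn.5.; the misses (ffn.0., ffn.2., ffn.4., ffn.5.)
    # are written back to the DHT as negative cache entries that expire after `expiration` seconds
    beam = dht.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)

    dht.shutdown()  # assumed: stops the background DHT process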

+ 9 - 3
hivemind/server/__init__.py

@@ -76,7 +76,7 @@ class Server(threading.Thread):
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-         means "sample random experts between myprefix.0.0 and myprefix.255.255;
+           means "sample random experts between myprefix.0.0 and myprefix.31.255";
         :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
         :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
         :param hidden_dim: main dimension for expert_cls
@@ -86,10 +86,16 @@ class Server(threading.Thread):
         :param optim_cls: uses this optimizer to train all experts
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
-         e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
+           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
+
         :param dht_port:  DHT node will listen on this port, default = find open port
-        You can then use this node as initial peer for subsequent servers.
+           You can then use this node as initial peer for subsequent servers.
+
         :param verbose: whether to print server started / finished / terminated events
+        :param compression: if specified, use this compression to pack all inputs, outputs and gradients of all experts
+            hosted on this server. For more fine-grained compression, start the server in Python and specify compression
+            for each BatchTensorProto in ExpertBackend for the respective experts.
+
         :param start: if True, starts server right away and returns when server is ready for requests
         """
         if verbose and len(kwargs) != 0:
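
A hedged sketch of the new compression parameter, assuming the keyword arguments documented above belong to the Server constructor (as the hunk context suggests) and that the CompressionType enum is importable from hivemind.proto.runtime_pb2 in this version; both are assumptions, not confirmed by the diff:

    import hivemind
    # assumed import path for the compression enum; adjust to your installed version
    from hivemind.proto.runtime_pb2 import CompressionType

    # pack all expert inputs, outputs and gradients as float16 on the wire;
    # for per-tensor control, configure compression on each BatchTensorProto
    # inside the corresponding ExpertBackend instead
    server = hivemind.Server(listen_on="0.0.0.0:*", num_experts=4, expert_cls='ffn',
                             hidden_dim=512, compression=CompressionType.FLOAT16,
                             start=True)
    server.shutdown()  # assumed: stops the server and its runtime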

+ 33 - 0
tests/test_dht_experts.py

@@ -1,6 +1,8 @@
 import random
 import numpy as np
 import pytest
+import asyncio
+import multiprocessing as mp
 
 import hivemind
 from hivemind import LOCALHOST, UidEndpoint
@@ -128,3 +130,34 @@ def test_uid_patterns():
         assert not hivemind.is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
     for pfx in invalid_prefixes:
         assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
+
+
+def test_negative_caching():
+    test_success = mp.Event()
+    peers = []
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, negative_caching=False, cache_locally=False, start=True))
+
+    normal_peer, writer_peer = random.sample(peers, 2)
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, negative_caching=True, cache_locally=False, start=True)
+
+    assert all(writer_peer.declare_experts(['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
+    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
+    assert len(neg_caching_peer.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
+
+    async def _tester():
+        node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
+        fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
+        for i in range(6):
+            assert fetched[i] is not None, f"node should have cached ffn.{i}."
+        for i in range(6, len(fetched)):
+            assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+        test_success.set()
+
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()
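
A usage note on the test above: the asyncio checks run in a child process via mp.Process rather than through asyncio.run in the test process itself, presumably to keep the throwaway DHTNode and its event loop isolated from the DHT peers already running as background processes. With a standard pytest setup, the test can be run on its own via pytest tests/test_dht_experts.py::test_negative_caching.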