
Change expiration time in declare_experts, fix update_period discrepancy (#482)

* hivemind-server will now set expiration dynamically based on update_period
* default update_period is now 30s everywhere; it was 5s in DHTHandlerThread but was overridden by the 30s default from Server.__init__

As @ial32 suggested on Discord, the hard-coded 300s expiration makes debugging harder and may cause problems if update_period >= 300, since experts would then expire from the DHT before the server re-declares them.
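
For reference, a minimal sketch of the new expiration rule (mirroring the DHTHandlerThread change in the diff below; MAX_DHT_TIME_DISCREPANCY_SECONDS and get_dht_time come from hivemind.utils, and the numbers are only illustrative):

    from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, get_dht_time

    update_period = 30.0  # seconds between re-declarations (new default)
    expiration = None     # left unset by the user
    if expiration is None:
        # entries should outlive at least two update periods and the allowed clock skew
        expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
    expiration_time = get_dht_time() + expiration  # absolute timestamp passed to declare_experts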


Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 3 years ago
commit a49ab1910e

+ 4 - 0
hivemind/hivemind_cli/run_server.py

@@ -50,6 +50,10 @@ def main():
                         help='LR scheduler type to use')
     parser.add_argument('--num_warmup_steps', type=int, required=False,
                         help='The number of warmup steps for LR schedule')
+    parser.add_argument('--update_period', type=float, required=False, default=30,
+                        help='Server will report experts to DHT once in this many seconds')
+    parser.add_argument('--expiration', type=float, required=False, default=None,
+                        help='DHT entries will expire after this many seconds')
     parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
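
For context, a hypothetical invocation using only the two new flags (any expert-definition flags your setup already uses are omitted):

    hivemind-server --update_period 10 --expiration 120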
 

+ 1 - 1
hivemind/moe/server/checkpoints.py

@@ -34,7 +34,7 @@ def copy_tree(src: str, dst: str):
 
 
 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
+    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float):
         super().__init__()
         assert is_directory(checkpoint_dir)
         self.expert_backends = expert_backends

+ 13 - 9
hivemind/moe/server/dht_handler.py

@@ -16,32 +16,35 @@ from hivemind.moe.expert_uid import (
     split_uid,
 )
 from hivemind.p2p import PeerID
-from hivemind.utils import MPFuture, get_dht_time
+from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs):
         super().__init__(**kwargs)
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
+        self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys())
+        declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys())
+            declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
 
 
 def declare_experts(
-    dht: DHT, uids: Sequence[ExpertUID], expiration: DHTExpiration = 300, wait: bool = True
+    dht: DHT, uids: Sequence[ExpertUID], expiration_time: DHTExpiration, wait: bool = True
 ) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
 
     :param uids: a list of expert ids to update
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param expiration: experts will be visible for this many seconds
+    :param expiration_time: experts will be visible until this absolute DHT timestamp (e.g. get_dht_time() + 30)
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
@@ -49,14 +52,15 @@ def declare_experts(
         uids = list(uids)
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-    return dht.run_coroutine(partial(_declare_experts, uids=uids, expiration=expiration), return_future=not wait)
+    return dht.run_coroutine(
+        partial(_declare_experts, uids=uids, expiration_time=expiration_time), return_future=not wait
+    )
 
 
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     peer_id_base58 = dht.peer_id.to_base58()
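
Call sites now pass an absolute DHT timestamp instead of a duration. A minimal usage sketch, mirroring the updated tests further down (assumes a single local DHT instance):

    from hivemind import DHT, get_dht_time
    from hivemind.moe.server.dht_handler import declare_experts, get_experts

    dht = DHT(start=True)
    # expiration_time is an absolute DHT timestamp, not a number of seconds
    statuses = declare_experts(dht, ["ffn.0", "ffn.1"], expiration_time=get_dht_time() + 30)
    assert all(statuses.values())
    assert get_experts(dht, ["ffn.0"])[0] is not None
    dht.shutdown()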
 

+ 8 - 1
hivemind/moe/server/server.py

@@ -46,6 +46,7 @@ class Server(threading.Thread):
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
         if dht is None, this parameter is ignored.
+    :param expiration: when server declares its experts to the DHT, these entries will expire after this many seconds
     :param start: if True, the server will immediately start as a background thread and returns control after server
         is ready (see .ready below)
     """
@@ -55,7 +56,8 @@ class Server(threading.Thread):
         dht: DHT,
         expert_backends: Dict[str, ExpertBackend],
         num_connection_handlers: int = 1,
-        update_period: int = 30,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
         start=False,
         checkpoint_dir=None,
         **kwargs,
@@ -75,6 +77,7 @@ class Server(threading.Thread):
                 experts=self.experts,
                 dht=self.dht,
                 update_period=self.update_period,
+                expiration=expiration,
                 daemon=True,
             )
 
@@ -103,6 +106,8 @@ class Server(threading.Thread):
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
         *,
         start: bool,
         **kwargs,
@@ -213,6 +218,8 @@ class Server(threading.Thread):
             device=device,
             checkpoint_dir=checkpoint_dir,
             stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
             start=start,
         )
 

+ 12 - 14
tests/test_dht_experts.py

@@ -6,10 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-from hivemind.dht import DHTNode
+from hivemind import get_dht_time
+from hivemind.dht.node import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid
-from hivemind.moe.server import declare_experts, get_experts
+from hivemind.moe.server.dht_handler import declare_experts, get_experts
 
 
 @pytest.mark.forked
@@ -24,14 +25,14 @@ def test_store_get_experts(n_peers=10):
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size])
+        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], get_dht_time() + 30)
 
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     other_expert = "my_other_expert.1337"
-    declare_experts(other_peer, [other_expert])
+    declare_experts(other_peer, [other_expert], get_dht_time() + 30)
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert first_found.peer_id == other_peer.peer_id
@@ -43,7 +44,7 @@ def test_store_get_experts(n_peers=10):
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ["new_expert.1"]))
+    assert all(declare_experts(remaining_peer1, ["new_expert.1"], expiration_time=get_dht_time() + 30))
     assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id
 
 
@@ -60,10 +61,7 @@ def test_beam_search(
     )
     for batch_start in range(0, len(real_experts), batch_size):
         dht = random.choice(dht_instances)
-        declare_experts(
-            dht,
-            real_experts[batch_start : batch_start + batch_size],
-        )
+        declare_experts(dht, real_experts[batch_start : batch_start + batch_size], get_dht_time() + 30)
 
     neighbors = sum(
         [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
@@ -90,14 +88,14 @@ def test_dht_single_node():
     node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
 
-    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"]).values())
-    assert len(declare_experts(node, ["ffn.1", "ffn.2"])) == 4
-    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7
+    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], get_dht_time() + 30).values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"], get_dht_time() + 30)) == 4
+    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], get_dht_time() + 30)) == 7
 
     for expert in get_experts(node, ["expert.3", "expert.2"]):
         assert expert.peer_id == node.peer_id
 
-    assert all(declare_experts(node, ["expert.5", "expert.2"]).values())
+    assert all(declare_experts(node, ["expert.5", "expert.2"], get_dht_time() + 30).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
 
@@ -196,7 +194,7 @@ async def test_negative_caching(n_peers=10):
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
     writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"]).values())
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values())
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)

+ 2 - 2
tests/test_moe.py

@@ -10,7 +10,7 @@ from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.utils import BatchTensorDescriptor, get_dht_time
 
 
 @pytest.mark.forked
@@ -163,7 +163,7 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids))
+    assert all(declare_experts(dht, all_expert_uids, expiration_time=get_dht_time() + 30))
 
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")