Implement a CLI for hivemind.DHT (#465)

* Implement a CLI for hivemind.DHT

* Fix log message in README

* Update examples/albert/README.md

* Add a basic test for hivemind-dht

* Move log_visible_maddrs to hivemind.utils.networking

Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Max Ryabinin 3 years ago
parent
commit
c49802aabd

+ 2 - 2
examples/albert/README.md

@@ -27,8 +27,8 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 ```
 $ ./run_training_monitor.py --wandb_project Demo-run
-Oct 14 16:26:36.083 [INFO] Running a DHT peer. To connect other peers to this one over the Internet,
-use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
+Oct 14 16:26:36.083 [INFO] Running a DHT instance. To connect other peers to this one, use
+ --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
 Oct 14 16:26:36.083 [INFO] Full list of visible multiaddresses: ...
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32

+ 2 - 1
examples/albert/run_trainer.py

@@ -19,6 +19,7 @@ from transformers.trainer_utils import is_main_process
 
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
 
 import utils
 from arguments import (
@@ -227,7 +228,7 @@ def main():
         announce_maddrs=collaboration_args.announce_maddrs,
         identity_path=collaboration_args.identity_path,
     )
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     if torch.cuda.device_count() != 0:

+ 2 - 1
examples/albert/run_training_monitor.py

@@ -14,6 +14,7 @@ from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, g
 import hivemind
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments
@@ -168,7 +169,7 @@ if __name__ == "__main__":
         announce_maddrs=monitor_args.announce_maddrs,
         identity_path=monitor_args.identity_path,
     )
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
 
     if monitor_args.wandb_project is not None:
         wandb.init(project=monitor_args.wandb_project)

+ 1 - 23
examples/albert/utils.py

@@ -1,13 +1,11 @@
 from typing import Dict, List, Tuple
 
-from multiaddr import Multiaddr
 from pydantic import BaseModel, StrictFloat, confloat, conint
 
-from hivemind import choose_ip_address
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
-from hivemind.utils.logging import TextStyle, get_logger
+from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
@@ -28,23 +26,3 @@ def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
     signature_validator = RSASignatureValidator()
     validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
     return validators, signature_validator.local_public_key
-
-
-def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
-    if only_p2p:
-        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
-        initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
-    else:
-        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
-        if available_ips:
-            preferred_ip = choose_ip_address(available_ips)
-            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
-        else:
-            selected_maddrs = visible_maddrs
-        initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
-
-    logger.info(
-        f"Running a DHT peer. To connect other peers to this one over the Internet, use "
-        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
-    )
-    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")

+ 1 - 1
hivemind/dht/dht.py

@@ -244,7 +244,7 @@ class DHT(mp.Process):
           DHT fields made by this coroutine will not be accessible from the host process.
         :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
           or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
-        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
+        :note: when run_coroutine is called with return_future=False, MPFuture can be cancelled to interrupt the task.
         """
         future = MPFuture()
         self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))

+ 2 - 2
hivemind/dht/node.py

@@ -146,7 +146,7 @@ class DHTNode:
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
         :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
-          nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
+          nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
         :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
           if they are accessed this many seconds before expiration time.
@@ -341,7 +341,7 @@ class DHTNode:
     ) -> bool:
         """
         Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
-        :note: store is a simplified interface to store_many, all kwargs are be forwarded there
+        :note: store is a simplified interface to store_many, all kwargs are forwarded there
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)

+ 76 - 0
hivemind/hivemind_cli/run_dht.py

@@ -0,0 +1,76 @@
+import time
+from argparse import ArgumentParser
+
+from hivemind.dht import DHT, DHTNode
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+async def report_status(dht: DHT, node: DHTNode):
+    logger.info(
+        f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
+        f"are in the local routing table "
+    )
+    logger.debug(f"Routing table contents: {node.protocol.routing_table}")
+    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
+    logger.debug(f"Local storage contents: {node.protocol.storage}")
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--initial_peers",
+        nargs="*",
+        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
+        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
+    )
+    parser.add_argument(
+        "--host_maddrs",
+        nargs="*",
+        default=["/ip4/0.0.0.0/tcp/0"],
+        help="Multiaddrs to listen for external connections from other DHT instances. "
+        "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
+    )
+    parser.add_argument(
+        "--announce_maddrs",
+        nargs="*",
+        help="Visible multiaddrs the host announces for external connections from other DHT instances",
+    )
+    parser.add_argument(
+        "--use_ipfs",
+        action="store_true",
+        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
+        "part of the multiaddrs for the initial_peers "
+        "(no need to specify a particular IPv4/IPv6 host and port)",
+    )
+    parser.add_argument(
+        "--identity_path",
+        help="Path to a private key file. If defined, makes the peer ID deterministic. "
+        "If the file does not exist, writes a new private key to this file.",
+    )
+    parser.add_argument(
+        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
+    )
+
+    args = parser.parse_args()
+
+    dht = DHT(
+        start=True,
+        initial_peers=args.initial_peers,
+        host_maddrs=args.host_maddrs,
+        announce_maddrs=args.announce_maddrs,
+        use_ipfs=args.use_ipfs,
+        identity_path=args.identity_path,
+    )
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
+
+    while True:
+        dht.run_coroutine(report_status, return_future=False)
+        time.sleep(args.refresh_period)
+
+
+if __name__ == "__main__":
+    main()
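
For orientation, a minimal local run of the new entry point could look like the following (a sketch: the flags are those defined by the argument parser above, while the port, peer ID, and log prefixes are illustrative and trimmed):

```
$ hivemind-dht --host_maddrs /ip4/127.0.0.1/tcp/0 --refresh_period 5
[INFO] Running a DHT instance. To connect other peers to this one, use --initial_peers /ip4/127.0.0.1/tcp/XXXX/p2p/XXXX
[INFO] Full list of visible multiaddresses: ...
[INFO] 1 DHT nodes (including this one) are in the local routing table
[INFO] Local storage contains 0 keys
```

A second peer can then join by passing the printed multiaddress via `--initial_peers`; the new test below does exactly this with two subprocesses.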

+ 1 - 1
hivemind/utils/__init__.py

@@ -3,7 +3,7 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
-from hivemind.utils.networking import *
+from hivemind.utils.networking import get_free_port, log_visible_maddrs
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

+ 25 - 1
hivemind/utils/networking.py

@@ -1,12 +1,16 @@
 import socket
 from contextlib import closing
 from ipaddress import ip_address
-from typing import Sequence
+from typing import List, Sequence
 
 from multiaddr import Multiaddr
 
+from hivemind.utils.logging import TextStyle, get_logger
+
 LOCALHOST = "127.0.0.1"
 
+logger = get_logger(__name__)
+
 
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """
@@ -52,3 +56,23 @@ def choose_ip_address(
                         return value_for_protocol
 
     raise ValueError(f"No IP address found among given multiaddrs: {maddrs}")
+
+
+def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
+    if only_p2p:
+        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
+        initial_peers = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
+    else:
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
+        if available_ips:
+            preferred_ip = choose_ip_address(available_ips)
+            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
+        else:
+            selected_maddrs = visible_maddrs
+        initial_peers = " ".join(str(addr) for addr in selected_maddrs)
+
+    logger.info(
+        f"Running a DHT instance. To connect other peers to this one, use "
+        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers}{TextStyle.RESET}"
+    )
+    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")

+ 1 - 0
setup.py

@@ -178,6 +178,7 @@ setup(
     ],
     entry_points={
         "console_scripts": [
+            "hivemind-dht = hivemind.hivemind_cli.run_dht:main",
             "hivemind-server = hivemind.hivemind_cli.run_server:main",
         ]
     },
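
Console scripts are generated at install time, so the new command only appears after the package is (re)installed. A typical development workflow (a sketch, assuming an editable install is used):

```
$ pip install -e .
$ hivemind-dht --help
```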

+ 63 - 0
tests/test_cli_scripts.py

@@ -0,0 +1,63 @@
+import re
+from subprocess import PIPE, Popen
+from time import sleep
+
+DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
+
+
+def test_dht_connection_successful():
+    dht_refresh_period = 1
+
+    dht_proc = Popen(
+        ["hivemind-dht", "--host_maddrs", "/ip4/127.0.0.1/tcp/0", "--refresh_period", str(dht_refresh_period)],
+        stderr=PIPE,
+        text=True,
+        encoding="utf-8",
+    )
+
+    first_line = dht_proc.stderr.readline()
+    second_line = dht_proc.stderr.readline()
+    dht_pattern_match = DHT_START_PATTERN.search(first_line)
+    assert dht_pattern_match is not None, first_line
+    assert "Full list of visible multiaddresses:" in second_line, second_line
+
+    initial_peers = dht_pattern_match.group(1).split(" ")
+
+    dht_client_proc = Popen(
+        ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
+        stderr=PIPE,
+        text=True,
+        encoding="utf-8",
+    )
+
+    # skip first two lines with connectivity info
+    for _ in range(2):
+        dht_client_proc.stderr.readline()
+    first_report_msg = dht_client_proc.stderr.readline()
+
+    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+
+    # ensure we get the output of dht_proc after the start of dht_client_proc
+    sleep(dht_refresh_period)
+
+    # expect that one of the next logging outputs from the first peer shows a new connection
+    for _ in range(5):
+        first_report_msg = dht_proc.stderr.readline()
+        second_report_msg = dht_proc.stderr.readline()
+
+        if (
+            "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+            and "Local storage contains 0 keys" in second_report_msg
+        ):
+            break
+    else:
+        assert (
+            "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+            and "Local storage contains 0 keys" in second_report_msg
+        )
+
+    dht_proc.terminate()
+    dht_client_proc.terminate()
+
+    dht_proc.wait()
+    dht_client_proc.wait()
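
The test depends only on the installed `hivemind-dht` script and the log lines it prints to stderr, so it can be run in isolation with a standard pytest invocation (assuming the suite is run with pytest, as elsewhere in the repository):

```
$ pytest tests/test_cli_scripts.py::test_dht_connection_successful
```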

+ 1 - 1
tests/test_routing.py

@@ -3,8 +3,8 @@ import operator
 import random
 from itertools import chain, zip_longest
 
-from hivemind import LOCALHOST
 from hivemind.dht.routing import DHTID, RoutingTable
+from hivemind.utils.networking import LOCALHOST
 
 
 def test_ids_basic():