浏览代码

Use public swarm by default (#92)

This PR makes servers and clients use public swarm's bootstrap peers if no other initial peers are specified.

If you'd like a server to start a new swarm, provide the `--new_swarm` CLI argument.
Alexander Borzunov 2 年之前
父节点
当前提交
dc71574a63
共有 6 个文件被更改,包括 46 次插入21 次删除
  1. 10 10
      .github/workflows/run-tests.yaml
  2. 14 5
      cli/run_server.py
  3. 3 2
      src/client/remote_model.py
  4. 8 0
      src/constants.py
  5. 3 0
      src/server/block_selection.py
  6. 8 4
      src/server/server.py

+ 10 - 10
.github/workflows/run-tests.yaml

@@ -72,15 +72,15 @@ jobs:
           export REF_NAME=bigscience/bloom-560m
 
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
-            --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
+            --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
             --torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log &
           SERVER1_PID=$!
-          
+
           sleep 5  # wait for the first server to initialize DHT
-          
+
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
-          
+
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
           SERVER2_PID=$!
@@ -94,20 +94,20 @@ jobs:
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
           SERVER4_PID=$!
-          
+
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server5.log &
           SERVER5_PID=$!
-          
+
           tail -n 100 -f server*.log &
           LOGGER_PID=$!
           sleep 30  # wait for servers to download layers
-          
+
           kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
-          
+
           PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
-          
+
           kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests
-          
+
           kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID
           echo "Done!"

+ 14 - 5
cli/run_server.py

@@ -6,6 +6,7 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from humanfriendly import parse_size
 
+from src.constants import PUBLIC_INITIAL_PEERS
 from src.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
@@ -27,10 +28,10 @@ def main():
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
                                                                  "use the same name as in the converted model.")
-    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
-                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port')
     parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
-                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+                        help='Visible multiaddrs the host announces for external connections from other peers')
 
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
 
@@ -71,8 +72,13 @@ def main():
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
-    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
-                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
+
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
+                       help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
+    group.add_argument('--new_swarm', action='store_true',
+                       help='Start a new private swarm (i.e., do not connect to any initial peers)')
+
     parser.add_argument('--increase_file_limit', action='store_true',
                         help='On *nix, this will increase the max number of processes '
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')
@@ -112,6 +118,9 @@ def main():
         attn_cache_size, (int, type(None))
     ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
 
+    if args.pop("new_swarm"):
+        args["initial_peers"] = []
+
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 

+ 3 - 2
src/client/remote_model.py

@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import Optional, Tuple
+from typing import List, Optional
 
 import hivemind
 import torch
@@ -17,6 +17,7 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
+from src.constants import PUBLIC_INITIAL_PEERS
 from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
@@ -29,7 +30,7 @@ class DistributedBloomConfig(BloomConfig):
     To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
     """
 
-    initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
+    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU

+ 8 - 0
src/constants.py

@@ -0,0 +1,8 @@
+PUBLIC_INITIAL_PEERS = [
+    "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
+    "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
+    "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
+    "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
+    "/dns/bootstrap3.petals.ml/tcp/31339/p2p/QmX82nfE57CSkNgyEC7pPMPBzjcFLLJXdHhvp1AXKVPvJD",
+    "/dns6/bootstrap3.petals.ml/tcp/31339/p2p/QmX82nfE57CSkNgyEC7pPMPBzjcFLLJXdHhvp1AXKVPvJD",
+]

+ 3 - 0
src/server/block_selection.py

@@ -106,6 +106,9 @@ def should_choose_other_blocks(
             throughputs[span.start : span.end] += span.throughput
 
     new_throughput = throughputs.min()
+    if new_throughput < initial_throughput or new_throughput < eps:
+        return False
+
     actual_quality = initial_throughput / new_throughput
     logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
 

+ 8 - 4
src/server/server.py

@@ -5,7 +5,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import psutil
@@ -18,6 +18,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from src.constants import PUBLIC_INITIAL_PEERS
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server import block_selection
@@ -39,6 +40,8 @@ class Server(threading.Thread):
 
     def __init__(
         self,
+        *,
+        initial_peers: List[str],
         prefix: Optional[str],
         converted_model_name_or_path: str,
         throughput: Union[float, str],
@@ -53,7 +56,6 @@ class Server(threading.Thread):
         cache_dir: Optional[str] = None,
         attn_cache_size: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
-        initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
@@ -66,7 +68,6 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
-        *,
         start: bool,
         **kwargs,
     ):
@@ -104,7 +105,10 @@ class Server(threading.Thread):
 
         self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
-        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        if initial_peers == PUBLIC_INITIAL_PEERS:
+            logger.info("Connecting to the public Petals swarm")
+        else:
+            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device