Kaynağa Gözat

Use public swarm by default

Aleksandr Borzunov 2 yıl önce
ebeveyn
işleme
60cdd38e27
3 değiştirilmiş dosya ile 19 ekleme ve 9 silme
  1. 14 5
      cli/run_server.py
  2. 3 2
      src/client/remote_model.py
  3. 2 2
      src/server/server.py

+ 14 - 5
cli/run_server.py

@@ -6,6 +6,7 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from humanfriendly import parse_size
 
+from src.constants import PUBLIC_INITIAL_PEERS
 from src.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
@@ -27,10 +28,10 @@ def main():
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
                                                                  "use the same name as in the converted model.")
-    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
-                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port')
     parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
-                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+                        help='Visible multiaddrs the host announces for external connections from other peers')
 
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
 
@@ -71,8 +72,13 @@ def main():
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
-    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
-                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
+
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
+                       help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
+    group.add_argument('--new_swarm', action='store_true',
+                       help='Start a new private swarm (i.e., do not connect to any initial peers)')
+
     parser.add_argument('--increase_file_limit', action='store_true',
                         help='On *nix, this will increase the max number of processes '
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')
@@ -112,6 +118,9 @@ def main():
         attn_cache_size, (int, type(None))
     ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
 
+    if args.pop("new_swarm"):
+        args.initial_peers = []
+
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 

+ 3 - 2
src/client/remote_model.py

@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import Optional, Tuple
+from typing import Optional, List
 
 import hivemind
 import torch
@@ -15,6 +15,7 @@ from src.bloom.model import (
     BloomPreTrainedModel,
     LMHead,
 )
+from src.constants import PUBLIC_INITIAL_PEERS
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
 from src.utils.misc import DUMMY
@@ -29,7 +30,7 @@ class DistributedBloomConfig(BloomConfig):
     To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
     """
 
-    initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
+    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU

+ 2 - 2
src/server/server.py

@@ -39,6 +39,8 @@ class Server(threading.Thread):
 
     def __init__(
         self,
+        *,
+        initial_peers: List[str],
         prefix: Optional[str],
         converted_model_name_or_path: str,
         throughput: Union[float, str],
@@ -53,7 +55,6 @@ class Server(threading.Thread):
         cache_dir: Optional[str] = None,
         attn_cache_size: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
-        initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
@@ -66,7 +67,6 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
-        *,
         start: bool,
         **kwargs,
     ):