
Refactor for v0.9.10 and fix example (#319)

- fixed imports in the example
- renamed listen=False -> client_mode=True (both names were present in the codebase; chose client_mode)
- increased the default metadata_expiration to better reflect an optimal training configuration
- changed the CollaborativeCallback backup mechanism to ensure that backups are stored in separate buffers on CPU
- renamed throughput -> bandwidth (both names were present; chose bandwidth because throughput also refers to compute, e.g. benchmark_throughput.py)
- re-ran examples/albert
- updated the example outputs in examples/albert/README.md
- added _parent_pid to DHT and *Averager in order to fix incorrect __del__ in some edge cases

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 4 years ago
parent
commit
11db5fd56f
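
For downstream code, here is a minimal before/after sketch of the renamed `DecentralizedAverager` keyword arguments (the surrounding argument values are illustrative, not taken from this commit):

```python
import torch
import hivemind
from hivemind.averaging.averager import DecentralizedAverager

dht = hivemind.DHT(start=True)
tensors = [torch.zeros(16)]

# Before this change: a client-only averager with an explicit bandwidth estimate
# averager = DecentralizedAverager(tensors, dht=dht, prefix="mygroup", target_group_size=2,
#                                  listen=False, throughput=100e6, start=True)

# After this change: listen=False becomes client_mode=True, throughput becomes bandwidth
averager = DecentralizedAverager(tensors, dht=dht, prefix="mygroup", target_group_size=2,
                                 client_mode=True, bandwidth=100e6, start=True)

averager.shutdown()
dht.shutdown()
```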

+ 4 - 3
examples/albert/README.md

@@ -28,7 +28,8 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 ```
 $ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:42] Running a DHT peer. To connect other peers to this one, use --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
+[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet, 
+use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
 wandb: Syncing run dry-mountain-2
@@ -61,9 +62,9 @@ To join the collaboration with a GPU trainer,
 
   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
   trainers)
-  collected from the first lines of their terminal output. For the example above, the multiaddresses would be:
+  collected from the first lines of their terminal output. For the example above, the (dummy) multiaddresses would be:
   ```
-  --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
+  --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
   ```
 
   <details>

+ 4 - 1
examples/albert/arguments.py

@@ -69,7 +69,7 @@ class AveragerArguments:
     )
     target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
-        default=30, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+        default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )
 
 
@@ -101,6 +101,9 @@ class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArgume
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
+    backup_every_steps: int = field(
+        default=10, metadata={"help": "In case of NaN, training is restored from a backup updated with this frequency."}
+    )
 
 
 @dataclass
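
The new `backup_every_steps` field follows the same dataclass-plus-`HfArgumentParser` pattern as the rest of this file. A minimal standalone sketch of how such a field is declared and parsed (the class below is hypothetical; the real field lives in `CollaborationArguments`):

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser


@dataclass
class BackupArguments:
    # Hypothetical standalone dataclass for illustration only
    backup_every_steps: int = field(
        default=10, metadata={"help": "In case of NaN, training is restored from a backup updated with this frequency."}
    )


parser = HfArgumentParser(BackupArguments)
(backup_args,) = parser.parse_args_into_dataclasses(["--backup_every_steps", "25"])
print(backup_args.backup_every_steps)  # 25
```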

+ 36 - 17
examples/albert/run_trainer.py

@@ -2,9 +2,10 @@
 
 import logging
 import os
+import pickle
 from dataclasses import asdict
 from pathlib import Path
-from typing import Dict, Any
+from typing import Any
 
 import torch
 import transformers
@@ -18,6 +19,8 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 import hivemind
+from hivemind.utils.compression import CompressionType
+
 import utils
 from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
 
@@ -93,6 +96,11 @@ def get_optimizer_and_scheduler(training_args, model):
 
 
 class CollaborativeCallback(transformers.TrainerCallback):
+    """
+    This callback monitors and reports collaborative training progress.
+    In case of a catastrophic failure, it can also revert training to a backup.
+    """
+
     def __init__(
         self,
         dht: hivemind.DHT,
@@ -100,6 +108,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
+        backup_every_steps: int,
     ):
         super().__init__()
         self.model = model
@@ -107,11 +116,12 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
-        self.previous_state = self.get_current_state()
         self.samples = 0
         self.steps = 0
         self.loss = 0
         self.total_samples_processed = 0
+        self.backup_every_steps = backup_every_steps
+        self.latest_backup = self.backup_state()
 
     def on_train_begin(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,9 +134,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
         control.should_log = True
         if not self.params_are_finite():
-            self.load_from_state(self.previous_state)
+            self.restore_from_backup(self.latest_backup)
             return control
-        self.previous_state = self.get_current_state()
 
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
@@ -146,6 +155,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
                     logger.info(f"Local loss: {self.loss / self.steps}")
+                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                    self.latest_backup = self.backup_state()
 
                 self.loss = 0
                 self.steps = 0
@@ -162,15 +173,6 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
         return control
 
-    @torch.no_grad()
-    def get_current_state(self) -> Dict[str, Any]:
-        return {"model": self.model.state_dict(), "opt": self.collaborative_optimizer.opt.state_dict()}
-
-    @torch.no_grad()
-    def load_from_state(self, state):
-        self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["opt"])
-
     @torch.no_grad()
     def params_are_finite(self):
         for param in self.model.parameters():
@@ -178,6 +180,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 return False
         return True
 
+    @torch.no_grad()
+    def backup_state(self) -> Any:
+        return pickle.dumps(
+            {"model": self.model.state_dict(), "training": self.collaborative_optimizer.opt.state_dict()}
+        )
+
+    @torch.no_grad()
+    def restore_from_backup(self, backup):
+        state = pickle.loads(backup)
+        self.model.load_state_dict(state["model"])
+        self.collaborative_optimizer.opt.load_state_dict(state["training"])
+
 
 class NoOpScheduler(LRSchedulerBase):
     """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
@@ -229,7 +243,7 @@ def main():
     dht = hivemind.DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
-        listen=not collaboration_args.client_mode,
+        client_mode=collaboration_args.client_mode,
         record_validators=validators,
         use_ipfs=collaboration_args.use_ipfs,
         host_maddrs=collaboration_args.host_maddrs,
@@ -248,9 +262,9 @@ def main():
         dht=dht,
         scheduler=scheduler,
         prefix=collaboration_args.experiment_prefix,
-        compression_type=hivemind.utils.CompressionType.Value(collaboration_args.compression),
+        compression_type=CompressionType.Value(collaboration_args.compression),
         batch_size_per_step=total_batch_size_per_step,
-        throughput=collaboration_args.bandwidth,
+        bandwidth=collaboration_args.bandwidth,
         target_batch_size=adjusted_target_batch_size,
         client_mode=collaboration_args.client_mode,
         verbose=True,
@@ -274,7 +288,12 @@ def main():
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
         callbacks=[
             CollaborativeCallback(
-                dht, collaborative_optimizer, model, local_public_key, collaboration_args.statistics_expiration
+                dht,
+                collaborative_optimizer,
+                model,
+                local_public_key,
+                collaboration_args.statistics_expiration,
+                collaboration_args.backup_every_steps,
             )
         ],
     )
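
The switch from storing a `state_dict()` reference to `pickle.dumps` matters because `state_dict()` returns tensors that share storage with the live parameters: after an in-place NaN blow-up, such a "backup" is corrupted too, whereas a pickled copy lives in an independent CPU-side buffer. A minimal standalone sketch of the difference (the toy model and simulated failure are illustrative):

```python
import pickle

import torch

model = torch.nn.Linear(4, 2)

# Reference-style "backup": the state_dict values share storage with the live parameters.
reference_backup = model.state_dict()

# Serialized backup, as in the new backup_state(): pickling copies tensor data into a separate buffer.
serialized_backup = pickle.dumps(model.state_dict())

with torch.no_grad():
    model.weight.fill_(float("nan"))  # simulate a catastrophic failure

print(torch.isnan(reference_backup["weight"]).any().item())                 # True: corrupted along with the model
print(torch.isnan(pickle.loads(serialized_backup)["weight"]).any().item())  # False: the pickled copy is intact

# Restore, mirroring restore_from_backup()
model.load_state_dict(pickle.loads(serialized_backup))
```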

+ 4 - 2
examples/albert/run_training_monitor.py

@@ -13,6 +13,8 @@ from torch_optimizer import Lamb
 from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
 
 import hivemind
+from hivemind.utils.compression import CompressionType
+
 import utils
 from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 
@@ -99,8 +101,8 @@ class CheckpointHandler:
             opt=opt,
             dht=dht,
             prefix=experiment_prefix,
-            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
-            throughput=collab_optimizer_args.bandwidth,
+            compression_type=CompressionType.Value(collab_optimizer_args.compression),
+            bandwidth=collab_optimizer_args.bandwidth,
             target_batch_size=adjusted_target_batch_size,
             client_mode=collab_optimizer_args.client_mode,
             verbose=True,

+ 12 - 2
examples/albert/utils.py

@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple
 from multiaddr import Multiaddr
 from pydantic import BaseModel, StrictFloat, confloat, conint
 
+from hivemind import choose_ip_address
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
@@ -41,8 +42,17 @@ def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}
         initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
     else:
-        initial_peers_str = " ".join(str(addr) for addr in visible_maddrs)
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr]
+        available_ips += [Multiaddr(addr) for addr in visible_maddrs if "ip6" in addr]
+        if available_ips:
+            preferred_ip = choose_ip_address(available_ips)
+            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
+        else:
+            selected_maddrs = visible_maddrs
+        initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
+
     logger.info(
-        f"Running a DHT peer. To connect other peers to this one, use "
+        f"Running a DHT peer. To connect other peers to this one over the Internet, use "
         f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
     )
+    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")

+ 32 - 31
hivemind/averaging/averager.py

@@ -31,7 +31,7 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port
+from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 
@@ -64,11 +64,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
-    :param throughput: if specified, this value represents the network bandwidth available to averager.
+    :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
-          If throughput == 0, averager will rely on its groupmates to do all the averaging.
-    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
-            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
+          If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
+    :param client_mode: if False (default), this averager will accept incoming requests from other peers
+            if True, the averager will only join existing groups where at least one peer has client_mode=False
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
     :param announced_host: visible IP address the averager will announce for external connections from other peers.
           If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
@@ -115,11 +115,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
         compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-        throughput: Optional[float] = None,
+        bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
-        listen: bool = True,
+        client_mode: bool = False,
         listen_on: Endpoint = "0.0.0.0:*",
         daemon: bool = True,
         announced_host: Optional[str] = None,
@@ -128,18 +128,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         **kwargs,
     ):
         assert "." not in prefix, "group prefix must be a string without trailing '.'"
-        assert throughput is None or (
-            throughput >= 0 and np.isfinite(np.float32(throughput))
-        ), "throughput must be a non-negative float32"
+        assert bandwidth is None or (
+            bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
+        ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
+        assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
-        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
-        if not self.listen:
+        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self._parent_pid = os.getpid()
+        if self.client_mode:
             self.mode = AveragingMode.CLIENT
         elif auxiliary:
             self.mode = AveragingMode.AUX
@@ -161,7 +162,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
-        self.throughput = throughput
+        self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
             prefix=prefix,
@@ -181,10 +182,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
-        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+        if allow_state_sharing is None:
+            allow_state_sharing = not client_mode and not auxiliary
+        self.allow_state_sharing = allow_state_sharing
 
         self._averager_endpoint: Optional[Endpoint] = None
-        if not self.listen:
+        if self.client_mode:
             self._averager_endpoint = f"client::{uuid.uuid4()}"
 
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
@@ -221,16 +224,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
-        if value is True and not self.listen:
-            logger.warning(
-                "Cannot allow state sharing: averager in client mode (listen=False) cannot share its state."
-            )
+        if value and self.client_mode:
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
         else:
             self._allow_state_sharing.value = value
 
     @property
     def endpoint(self) -> Optional[Endpoint]:
-        if self.listen and self._averager_endpoint is None:
+        if self._averager_endpoint is None and not self.client_mode:
             assert self.port is not None, "Averager is not running yet"
             self._averager_endpoint = f"{self.announced_host}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
@@ -258,7 +259,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             async def _run():
                 grpc.aio.init_grpc_aio()
 
-                if self.listen:
+                if not self.client_mode:
                     self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
                     averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
                     found_port = self._server.add_insecure_port(self.listen_on)
@@ -269,9 +270,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=not self.listen
+                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
                 )
-                if self.listen:
+                if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
 
                 self._pending_group_assembled = asyncio.Event()
@@ -312,7 +313,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        if self.listen:
+        if not self.client_mode:
             remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)
 
@@ -374,7 +375,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.throughput, self.mode.value, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
@@ -422,16 +423,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
-            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
-            incoming_throughputs = [
-                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)
+            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
             ]
             peer_fractions = await asyncio.get_event_loop().run_in_executor(
-                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
             async with self.get_tensors_async() as local_tensors:
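
Besides the renames, the `allow_state_sharing` setter now raises instead of merely warning when a client-mode averager attempts to share its state. A minimal sketch of the new behavior (constructor arguments are illustrative):

```python
import torch
import hivemind
from hivemind.averaging.averager import DecentralizedAverager

dht = hivemind.DHT(start=True)
averager = DecentralizedAverager(
    [torch.zeros(8)], dht=dht, prefix="mygroup", target_group_size=2, client_mode=True, start=True
)

try:
    averager.allow_state_sharing = True  # previously this only logged a warning
except ValueError as exc:
    print(exc)  # Cannot allow state sharing: averager in client mode cannot share its state.

averager.shutdown()
dht.shutdown()
```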

+ 21 - 21
hivemind/averaging/load_balancing.py

@@ -9,30 +9,30 @@ logger = get_logger(__name__)
 LOAD_BALANCING_LP_DECIMALS = 9
 
 
-def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
+def load_balance_peers(vector_size, bandwidths: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
     """
-    Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
+    Find an optimal partitioning of weights for butterfly all-reduce given peer bandwidths.
     :param vector_size: total size of the averaged vector (in elements, not bytes)
-    :param throughputs: 1d array of non-negative throughputs for each peer capable of averaging
+    :param bandwidths: 1d array of non-negative bandwidths for each peer capable of averaging
       zeros stand for client-only participants, None represents "not specified" (resolved as mean of other peers)
     :param min_size: peers that can aggregate less than this many elements will be assigned nothing
     :returns: an integer array where i-th element is the number of weights assigned to i-th peer
     """
-    specified_throughputs = [throughput for throughput in throughputs if throughput is not None and throughput > 0]
+    specified_bandwidth = [item for item in bandwidths if item is not None and item > 0]
 
-    if specified_throughputs:
-        default_throughput = np.mean(specified_throughputs)
-        throughputs = [throughput if throughput is not None else default_throughput for throughput in throughputs]
-        scores = optimize_parts_lp(vector_size, np.asarray(throughputs), min_size)
+    if specified_bandwidth:
+        default_bandwidth = np.mean(specified_bandwidth)
+        bandwidths = [item if item is not None else default_bandwidth for item in bandwidths]
+        scores = optimize_parts_lp(vector_size, np.asarray(bandwidths), min_size)
     else:
-        assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
-        scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
+        assert not all(item == 0 for item in bandwidths), "Must have at least one nonzero bandwidth"
+        scores = np.asarray([1.0 if item is None else 0.0 for item in bandwidths])
 
     # TODO(jheuristic) we no longer need hagenbach-bishoff with new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))
 
 
-def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int = 0) -> np.ndarray:
+def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int = 0) -> np.ndarray:
     """
     This method solves an optimization problem to minimize the total allreduce time.
     In butterfly all-reduce, each peer acts both as a "client" and as an "aggregator":
@@ -42,20 +42,20 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
     Peer i network load as a "client" = vector_size * (1 - fraction_assigned_to_peer_i)
     Peer i network load as an "aggregator" = vector_size * (group_size - 1) * fraction_assigned_to_peer_i
     Peer i total communication = vector_size * [1 + (group_size - 2) * fraction_assigned_to_peer_i]
-    Total time = max_i (total_communication_for_peer_i / throughputs[i])
+    Total time = max_i (total_communication_for_peer_i / bandwidths[i])
 
     We solve this optimization problem by reducing it to linear programming with a minimax reduction
     (see lecture notes: https://www.usna.edu/Users/math/dphillip/sa305.s15/phillips/lessons/32/32.pdf )
 
     :returns: a vector of "scores", i-th score is proportional to the fraction of weights assigned to i-th peer
     """
-    assert np.all(throughputs >= 0) and np.any(throughputs > 0)
-    throughputs = np.asarray(throughputs, dtype=np.float64)
-    permutation = np.argsort(-throughputs)
-    throughputs = throughputs[permutation]
-    is_nonzero = throughputs != 0
+    assert np.all(bandwidths >= 0) and np.any(bandwidths > 0)
+    bandwidths = np.asarray(bandwidths, dtype=np.float64)
+    permutation = np.argsort(-bandwidths)
+    bandwidths = bandwidths[permutation]
+    is_nonzero = bandwidths != 0
 
-    group_size = len(throughputs)
+    group_size = len(bandwidths)
     num_variables = group_size + 1  # [w_1, ..., w_N, xi]
 
     c = np.zeros(num_variables, dtype=np.float64)
@@ -64,9 +64,9 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
     # the constraints below are tuples (A, b) such that Ax <= b
     nonnegative_weights = -np.eye(group_size, num_variables, dtype=c.dtype), np.zeros(group_size, c.dtype)
     weights_sum_to_one = c[None, :] - 1.0, np.array([-1.0])
-    coeff_per_variable = (group_size - 2.0) / np.maximum(throughputs, 10 ** -LOAD_BALANCING_LP_DECIMALS)
+    coeff_per_variable = (group_size - 2.0) / np.maximum(bandwidths, 10 ** -LOAD_BALANCING_LP_DECIMALS)
     coeff_matrix_minus_xi = np.hstack([np.diag(coeff_per_variable), -np.ones((group_size, 1), c.dtype)])
-    xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / throughputs[is_nonzero]
+    xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / bandwidths[is_nonzero]
     force_max_weights = np.eye(group_size, M=num_variables, dtype=c.dtype), is_nonzero.astype(c.dtype)
 
     A, b = list(map(np.concatenate, zip(nonnegative_weights, weights_sum_to_one, xi_is_maximum, force_max_weights)))
@@ -79,7 +79,7 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
             peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
         peer_scores = np.round(peer_scores, LOAD_BALANCING_LP_DECIMALS)
     else:
-        logger.error(f"Failed to solve load-balancing for bandwidths {throughputs}.")
+        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}.")
         peer_scores = np.ones(group_size, c.dtype)
 
     return peer_scores[np.argsort(permutation)]
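
A minimal usage sketch of the renamed `load_balance_peers` (the bandwidth values are illustrative); the invariants asserted here mirror the updated tests:

```python
from hivemind.averaging.load_balancing import load_balance_peers

vector_size = 1_000_000
# Two peers with measured bandwidth, one unspecified (None -> mean of the others), one client-only (0.0).
bandwidths = [100e6, 50e6, None, 0.0]

assignment = load_balance_peers(vector_size, bandwidths, min_size=1024)

assert sum(assignment) == vector_size  # every element is assigned to exactly one peer
assert min(assignment) >= 0
assert assignment[-1] == 0             # client-only peers aggregate nothing
print(assignment)
```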

+ 1 - 7
hivemind/dht/__init__.py

@@ -26,7 +26,6 @@ from multiaddr import Multiaddr
 from hivemind.dht.node import DHTID, DHTNode
 from hivemind.dht.routing import DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.p2p import P2P
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
@@ -40,9 +39,6 @@ class DHT(mp.Process):
     * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
     * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
 
-    :param p2p: instance of hivemind.p2p.P2P that will be used for communication.
-      If None, DHTNode will create and manage its own P2P instance with given initial_peers and
-      parameters from ``kwargs``
     :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
@@ -60,7 +56,6 @@ class DHT(mp.Process):
 
     def __init__(
         self,
-        p2p: Optional[P2P] = None,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         *,
         start: bool,
@@ -70,9 +65,9 @@ class DHT(mp.Process):
         shutdown_timeout: float = 3,
         **kwargs,
     ):
+        self._parent_pid = os.getpid()
         super().__init__()
 
-        self.p2p = p2p
         if not (
             initial_peers is None
             or (
@@ -101,7 +96,6 @@ class DHT(mp.Process):
 
             async def _run():
                 self._node = await DHTNode.create(
-                    p2p=self.p2p,
                     initial_peers=self.initial_peers,
                     num_workers=self.max_workers or 1,
                     record_validator=self._record_validator,
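
The diff only records `_parent_pid` at construction time; the `__del__` guard it enables is not shown in this commit, so the following is a hypothetical sketch of the usual pattern rather than the actual implementation:

```python
import os


class BackgroundWorker:
    """Hypothetical stand-in for DHT / DecentralizedAverager, illustrating the _parent_pid guard."""

    def __init__(self):
        self._parent_pid = os.getpid()  # remember which process created this object
        self._running = True

    def shutdown(self):
        self._running = False  # the real classes stop a background process here

    def __del__(self):
        # Only the process that created the object should tear it down; after a fork,
        # children inherit a copy of this object and must not trigger shutdown from __del__.
        if os.getpid() == self._parent_pid and self._running:
            self.shutdown()
```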

+ 4 - 4
hivemind/dht/node.py

@@ -114,7 +114,7 @@ class DHTNode:
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
-        listen: bool = True,
+        client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
         validate: bool = True,
@@ -154,8 +154,8 @@ class DHTNode:
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
         :param validate: if True, use initial peers to validate that this node is accessible and synchronized
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
-          if False, this node will refuse any incoming request, effectively being only a "client"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
+          if True, this node will refuse any incoming requests, effectively being only a client
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
           for a given authorization protocol
@@ -203,7 +203,7 @@ class DHTNode:
             wait_timeout,
             parallel_rpc,
             cache_size,
-            listen,
+            client_mode,
             record_validator,
             authorizer,
         )

+ 6 - 6
hivemind/dht/protocol.py

@@ -43,7 +43,7 @@ class DHTProtocol(ServicerBase):
         wait_timeout: float,
         parallel_rpc: Optional[int] = None,
         cache_size: Optional[int] = None,
-        listen=True,
+        client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
     ) -> DHTProtocol:
@@ -66,15 +66,15 @@ class DHTProtocol(ServicerBase):
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float("inf"))
-        self.listen = listen
+        self.client_mode = client_mode
         self.record_validator = record_validator
         self.authorizer = authorizer
 
-        if listen:
+        if not client_mode:
             await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))
 
             self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
-        else:  # client-only mode
+        else:
             # note: use empty node_info so peers won't add you to their routing tables
             self.node_info = dht_pb2.NodeInfo()
         return self
@@ -95,7 +95,7 @@ class DHTProtocol(ServicerBase):
         :param peer: peer ID to ping
         :param validate: if True, validates that node's peer_id is available
         :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
-        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
+        :note: if DHTProtocol was created with client_mode=False, also request peer to add you to his routing table
 
         :return: node's DHTID, if peer responded and decided to send his node_id
         """
@@ -112,7 +112,7 @@ class DHTProtocol(ServicerBase):
 
         if responded and validate:
             try:
-                if self.listen and not response.available:
+                if not self.client_mode and not response.available:
                     raise ValidationError(
                         f"Peer {peer} can't access this node. " f"Probably, libp2p has failed to bypass the firewall"
                     )

+ 1 - 1
hivemind/optim/adaptive.py

@@ -29,6 +29,6 @@ class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
             average_opt_statistics=average_opt_statistics,
             prefix=f"{self.prefix}_averaging",
             allreduce_timeout=self.averaging_timeout,
-            listen=not self.client_mode,
+            client_mode=self.client_mode,
             **kwargs,
         )

+ 2 - 2
hivemind/optim/collaborative.py

@@ -167,7 +167,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             average_gradients=True,
             prefix=f"{self.prefix}_averaging",
             allreduce_timeout=self.averaging_timeout,
-            listen=not self.client_mode,
+            client_mode=self.client_mode,
             **kwargs,
         )
 
@@ -359,7 +359,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,
                     time=current_time,
-                    client_mode=not self.averager.listen,
+                    client_mode=self.averager.client_mode,
                 )
 
             self.dht.store(

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -70,7 +70,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    SOFT_UPDATE_TIMEOUT = 0.1  # seconds spent awaiting status update before warning is printed
+    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
     HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
 
     def __init__(self, *, synchronize: bool = True, use_lock: bool = True):

+ 13 - 13
tests/test_averaging.py

@@ -76,7 +76,7 @@ def _test_allreduce_once(n_clients, n_aux):
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
-            listen=mode != AveragingMode.CLIENT,
+            client_mode=mode == AveragingMode.CLIENT,
             listen_on="127.0.0.1:*",
             auxiliary=mode == AveragingMode.AUX,
             start=True,
@@ -121,8 +121,8 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True)
 
     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
+    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
+    random.shuffle(client_modes)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
@@ -135,11 +135,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
-            listen=listen,
+            client_mode=client_mode,
             listen_on="127.0.0.1:*",
             start=True,
         )
-        for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)
+        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
     ]
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [
@@ -180,7 +180,7 @@ def test_allreduce_compression():
             [x.clone() for x in tensors1],
             dht=dht,
             compression_type=compression_type_pair,
-            listen=False,
+            client_mode=True,
             target_group_size=2,
             prefix="mygroup",
             start=True,
@@ -306,16 +306,16 @@ def test_allgather():
     dht.shutdown()
 
 
-def get_cost(vector_size, partitions, throughputs):
+def get_cost(vector_size, partitions, bandwidths):
     return max(
-        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
+        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
         for i in range(len(partitions))
     )
 
 
-def check_optimality(vector_size, throughputs, ref_partitions):
-    partitions = list(load_balance_peers(vector_size, throughputs))
-    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)
+def check_optimality(vector_size, bandwidths, ref_partitions):
+    partitions = list(load_balance_peers(vector_size, bandwidths))
+    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)
 
 
 @pytest.mark.forked
@@ -342,9 +342,9 @@ def test_load_balancing():
         vector_size = np.random.randint(1, 1024 ** 3)
         num_peers = np.random.randint(1, 256)
         scale = 1e-9 + np.random.rand() * 1e5
-        throughputs = np.random.rand(num_peers) * scale + 1e-6
+        bandwidths = np.random.rand(num_peers) * scale + 1e-6
         min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
-        assignment = load_balance_peers(vector_size, throughputs, min_size)
+        assignment = load_balance_peers(vector_size, bandwidths, min_size)
         assert np.sum(assignment) == vector_size
         assert np.min(assignment) >= 0
 

+ 1 - 1
tests/test_dht.py

@@ -102,7 +102,7 @@ async def test_dht_get_visible_maddrs():
 
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(p2p, start=True)
+    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
 
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
     dht.shutdown()

+ 6 - 5
tests/test_dht_node.py

@@ -77,11 +77,12 @@ def test_dht_protocol():
     peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
 
     loop = asyncio.get_event_loop()
-    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
+        peer_id = DHTID.generate()
         p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
         protocol = loop.run_until_complete(
             DHTProtocol.create(
-                p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen
+                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
             )
         )
         logger.info(f"Self id={protocol.node_id}")
@@ -150,7 +151,7 @@ def test_dht_protocol():
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
 
-        if listen:
+        if not client_mode:
             loop.run_until_complete(p2p.shutdown())
 
     peer1_proc.terminate()
@@ -166,7 +167,7 @@ def test_empty_table():
     p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
     protocol = loop.run_until_complete(
         DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False
+            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
         )
     )
 
@@ -353,7 +354,7 @@ async def test_dhtnode_caching(T=0.05):
     node1 = await DHTNode.create(
         initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
         cache_refresh_before_expiry=5 * T,
-        listen=False,
+        client_mode=True,
         reuse_get_requests=False,
     )
     await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)