
Refactor for v0.9.10 and fix example (#319)

- fixed imports in example
- renamed listen=False -> client_mode=True (both options were present; chose client_mode; see the usage sketch below)
- increased the default metadata_expiration to better reflect the optimal training configuration
- changed the CollaborativeCallback backup mechanism so that backups are stored in separate buffers on CPU
- renamed throughput -> bandwidth (both options were present; chose bandwidth because throughput also refers to compute, e.g. benchmark_throughput.py)
- re-run examples/albert
- update example outputs in examples/albert/README.md
- added _parent_pid to DHT and *Averager in order to fix incorrect __del__ in some edge cases
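
For reference, a minimal sketch of how the renamed options look at a call site after this change. The constructor path and remaining keyword arguments are taken from the tests further down this page and may differ in other hivemind versions:

```python
import torch
import hivemind

# assumption: DecentralizedAverager is importable via hivemind.averaging
# (see hivemind/averaging/averager.py in the diff below)
dht = hivemind.DHT(start=True)
averager = hivemind.averaging.DecentralizedAverager(
    [torch.zeros(8)],      # tensors this peer contributes to averaging
    dht=dht,
    prefix="mygroup",
    target_group_size=2,
    client_mode=True,      # formerly listen=False
    bandwidth=100e6,       # formerly throughput=...
    start=True,
)
```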

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 4 years ago
Parent
Current commit
11db5fd56f

+ 4 - 3
examples/albert/README.md

@@ -28,7 +28,8 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,

 ```
 $ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:42] Running a DHT peer. To connect other peers to this one, use --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
+[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet, 
+use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
 wandb: Syncing run dry-mountain-2
@@ -61,9 +62,9 @@ To join the collaboration with a GPU trainer,

   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
   trainers)
-  collected from the first lines of their terminal output. For the example above, the multiaddresses would be:
+  collected from the first lines of their terminal output. For the example above, the (dummy) multiaddresses would be:
   ```
-  --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
+  --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
   ```

   <details>

+ 4 - 1
examples/albert/arguments.py

@@ -69,7 +69,7 @@ class AveragerArguments:
     )
     target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
-        default=30, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+        default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )


@@ -101,6 +101,9 @@ class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArgume
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
+    backup_every_steps: int = field(
+        default=10, metadata={"help": "In case of NaN, training restore from a backup updated with this frequency."}
+    )


 @dataclass
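
A quick way to confirm the new option and its default without running training (a sketch; assumes examples/albert is on the Python path):

```python
from dataclasses import fields

from arguments import CollaborationArguments  # examples/albert/arguments.py

# inspect the field added above instead of instantiating the dataclass,
# since its required arguments are unrelated to this change
backup_field = next(f for f in fields(CollaborationArguments) if f.name == "backup_every_steps")
print(backup_field.default)            # -> 10
print(backup_field.metadata["help"])   # restore-from-backup frequency used after a NaN
```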

+ 36 - 17
examples/albert/run_trainer.py

@@ -2,9 +2,10 @@

 import logging
 import os
+import pickle
 from dataclasses import asdict
 from pathlib import Path
-from typing import Dict, Any
+from typing import Any

 import torch
 import transformers
@@ -18,6 +19,8 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process

 import hivemind
+from hivemind.utils.compression import CompressionType
+
 import utils
 from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments

@@ -93,6 +96,11 @@ def get_optimizer_and_scheduler(training_args, model):


 class CollaborativeCallback(transformers.TrainerCallback):
+    """
+    This callback monitors and reports collaborative training progress,
+    In case of a catastrophic failure, it can also revert training to a backup
+    """
+
     def __init__(
         self,
         dht: hivemind.DHT,
@@ -100,6 +108,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
+        backup_every_steps: int,
     ):
         super().__init__()
         self.model = model
@@ -107,11 +116,12 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
-        self.previous_state = self.get_current_state()
         self.samples = 0
         self.steps = 0
         self.loss = 0
         self.total_samples_processed = 0
+        self.backup_every_steps = backup_every_steps
+        self.latest_backup = self.backup_state()

     def on_train_begin(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,9 +134,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
         control.should_log = True
         if not self.params_are_finite():
-            self.load_from_state(self.previous_state)
+            self.restore_from_backup(self.latest_backup)
             return control
-        self.previous_state = self.get_current_state()

         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
@@ -146,6 +155,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
                     logger.info(f"Local loss: {self.loss / self.steps}")
+                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                    self.latest_backup = self.backup_state()

                 self.loss = 0
                 self.steps = 0
@@ -162,15 +173,6 @@ class CollaborativeCallback(transformers.TrainerCallback):

         return control

-    @torch.no_grad()
-    def get_current_state(self) -> Dict[str, Any]:
-        return {"model": self.model.state_dict(), "opt": self.collaborative_optimizer.opt.state_dict()}
-
-    @torch.no_grad()
-    def load_from_state(self, state):
-        self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["opt"])
-
     @torch.no_grad()
     def params_are_finite(self):
         for param in self.model.parameters():
@@ -178,6 +180,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 return False
         return True

+    @torch.no_grad()
+    def backup_state(self) -> Any:
+        return pickle.dumps(
+            {"model": self.model.state_dict(), "training": self.collaborative_optimizer.opt.state_dict()}
+        )
+
+    @torch.no_grad()
+    def restore_from_backup(self, backup):
+        state = pickle.loads(backup)
+        self.model.load_state_dict(state["model"])
+        self.collaborative_optimizer.opt.load_state_dict(state["training"])
+

 class NoOpScheduler(LRSchedulerBase):
     """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
@@ -229,7 +243,7 @@ def main():
     dht = hivemind.DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
-        listen=not collaboration_args.client_mode,
+        client_mode=collaboration_args.client_mode,
         record_validators=validators,
         use_ipfs=collaboration_args.use_ipfs,
         host_maddrs=collaboration_args.host_maddrs,
@@ -248,9 +262,9 @@ def main():
         dht=dht,
         scheduler=scheduler,
         prefix=collaboration_args.experiment_prefix,
-        compression_type=hivemind.utils.CompressionType.Value(collaboration_args.compression),
+        compression_type=CompressionType.Value(collaboration_args.compression),
         batch_size_per_step=total_batch_size_per_step,
-        throughput=collaboration_args.bandwidth,
+        bandwidth=collaboration_args.bandwidth,
         target_batch_size=adjusted_target_batch_size,
         client_mode=collaboration_args.client_mode,
         verbose=True,
@@ -274,7 +288,12 @@ def main():
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
         callbacks=[
             CollaborativeCallback(
-                dht, collaborative_optimizer, model, local_public_key, collaboration_args.statistics_expiration
+                dht,
+                collaborative_optimizer,
+                model,
+                local_public_key,
+                collaboration_args.statistics_expiration,
+                collaboration_args.backup_every_steps,
             )
         ],
     )
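
For context, the backup mechanism above boils down to serializing both state dicts into a single bytes object; a standalone sketch with a toy model and optimizer (not the exact trainer code):

```python
import pickle

import torch

model = torch.nn.Linear(4, 2)                      # stand-in for the ALBERT model
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in for collaborative_optimizer.opt

def backup_state() -> bytes:
    # serializing copies the tensors out of the live training buffers
    return pickle.dumps({"model": model.state_dict(), "training": opt.state_dict()})

def restore_from_backup(backup: bytes) -> None:
    state = pickle.loads(backup)
    model.load_state_dict(state["model"])
    opt.load_state_dict(state["training"])

latest_backup = backup_state()
# ... train; if parameters become non-finite (NaN/Inf), roll back:
restore_from_backup(latest_backup)
```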

+ 4 - 2
examples/albert/run_training_monitor.py

@@ -13,6 +13,8 @@ from torch_optimizer import Lamb
 from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser

 import hivemind
+from hivemind.utils.compression import CompressionType
+
 import utils
 from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments

@@ -99,8 +101,8 @@ class CheckpointHandler:
             opt=opt,
             dht=dht,
             prefix=experiment_prefix,
-            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
-            throughput=collab_optimizer_args.bandwidth,
+            compression_type=CompressionType.Value(collab_optimizer_args.compression),
+            bandwidth=collab_optimizer_args.bandwidth,
             target_batch_size=adjusted_target_batch_size,
             client_mode=collab_optimizer_args.client_mode,
             verbose=True,

+ 12 - 2
examples/albert/utils.py

@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple
 from multiaddr import Multiaddr
 from pydantic import BaseModel, StrictFloat, confloat, conint

+from hivemind import choose_ip_address
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
@@ -41,8 +42,17 @@ def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}
         initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
     else:
-        initial_peers_str = " ".join(str(addr) for addr in visible_maddrs)
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr]
+        available_ips += [Multiaddr(addr) for addr in visible_maddrs if "ip6" in addr]
+        if available_ips:
+            preferred_ip = choose_ip_address(available_ips)
+            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
+        else:
+            selected_maddrs = visible_maddrs
+        initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
+
     logger.info(
-        f"Running a DHT peer. To connect other peers to this one, use "
+        f"Running a DHT peer. To connect other peers to this one over the Internet, use "
         f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
     )
+    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")

+ 32 - 31
hivemind/averaging/averager.py

@@ -31,7 +31,7 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port
+from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration

@@ -64,11 +64,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
-    :param throughput: if specified, this value represents the network bandwidth available to averager.
+    :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
-          If throughput == 0, averager will rely on its groupmates to do all the averaging.
-    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
-            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
+          If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
+    :param client_mode: if False (default), this averager will accept incoming requests from other peers
+            if True, the averager will only join existing groups where at least one peer has client_mode=False
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
     :param announced_host: visible IP address the averager will announce for external connections from other peers.
           If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
@@ -115,11 +115,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
         compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-        throughput: Optional[float] = None,
+        bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
-        listen: bool = True,
+        client_mode: bool = False,
         listen_on: Endpoint = "0.0.0.0:*",
         daemon: bool = True,
         announced_host: Optional[str] = None,
@@ -128,18 +128,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         **kwargs,
     ):
         assert "." not in prefix, "group prefix must be a string without trailing '.'"
-        assert throughput is None or (
-            throughput >= 0 and np.isfinite(np.float32(throughput))
-        ), "throughput must be a non-negative float32"
+        assert bandwidth is None or (
+            bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
+        ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
+        assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

         super().__init__()
         self.dht = dht
-        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
-        if not self.listen:
+        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self._parent_pid = os.getpid()
+        if self.client_mode:
             self.mode = AveragingMode.CLIENT
         elif auxiliary:
             self.mode = AveragingMode.AUX
@@ -161,7 +162,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
-        self.throughput = throughput
+        self.bandwidth = bandwidth

         self.matchmaking_kwargs = dict(
             prefix=prefix,
@@ -181,10 +182,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port

         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
-        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+        if allow_state_sharing is None:
+            allow_state_sharing = not client_mode and not auxiliary
+        self.allow_state_sharing = allow_state_sharing

         self._averager_endpoint: Optional[Endpoint] = None
-        if not self.listen:
+        if self.client_mode:
             self._averager_endpoint = f"client::{uuid.uuid4()}"

         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
@@ -221,16 +224,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
-        if value is True and not self.listen:
-            logger.warning(
-                "Cannot allow state sharing: averager in client mode (listen=False) cannot share its state."
-            )
+        if value and self.client_mode:
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
         else:
             self._allow_state_sharing.value = value

     @property
     def endpoint(self) -> Optional[Endpoint]:
-        if self.listen and self._averager_endpoint is None:
+        if self._averager_endpoint is None and not self.client_mode:
             assert self.port is not None, "Averager is not running yet"
             self._averager_endpoint = f"{self.announced_host}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
@@ -258,7 +259,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             async def _run():
                 grpc.aio.init_grpc_aio()

-                if self.listen:
+                if not self.client_mode:
                     self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
                     averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
                     found_port = self._server.add_insecure_port(self.listen_on)
@@ -269,9 +270,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     logger.debug(f"The averager is running in client mode.")

                 self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=not self.listen
+                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
                 )
-                if self.listen:
+                if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())

                 self._pending_group_assembled = asyncio.Event()
@@ -312,7 +313,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        if self.listen:
+        if not self.client_mode:
             remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)

@@ -374,7 +375,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.throughput, self.mode.value, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
@@ -422,16 +423,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))

-            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
-            incoming_throughputs = [
-                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)
+            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
             ]
             peer_fractions = await asyncio.get_event_loop().run_in_executor(
-                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )

             async with self.get_tensors_async() as local_tensors:
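
One behavioral change worth noting above: allow_state_sharing now raises instead of logging a warning when the averager runs in client mode. A hedged sketch of the new behavior (constructor arguments mirror the tests in this commit; adjust for your setup):

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
client = hivemind.averaging.DecentralizedAverager(
    [torch.zeros(4)], dht=dht, prefix="mygroup", target_group_size=2, client_mode=True, start=True
)
try:
    client.allow_state_sharing = True   # used to log a warning and silently refuse
except ValueError as e:
    print(e)                            # client-mode averagers cannot share their state
finally:
    client.shutdown()
    dht.shutdown()
```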

+ 21 - 21
hivemind/averaging/load_balancing.py

@@ -9,30 +9,30 @@ logger = get_logger(__name__)
 LOAD_BALANCING_LP_DECIMALS = 9


-def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
+def load_balance_peers(vector_size, bandwidths: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
     """
-    Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
+    Find an optimal partitioning of weights for butterfly all-reduce given peer bandwidths.
     :param vector_size: total size of the averaged vector (in elements, not bytes)
-    :param throughputs: 1d array of non-negative throughputs for each peer capable of averaging
+    :param bandwidths: 1d array of non-negative bandwidths for each peer capable of averaging
       zeros stand for client-only participants, None represents "not specified" (resolved as mean of other pears)
     :param min_size: peers that can aggregate less than this many elements will be assigned nothing
     :returns: an integer array where i-th element is the number of weights assigned to i-th peer
     """
-    specified_throughputs = [throughput for throughput in throughputs if throughput is not None and throughput > 0]
+    specified_bandwidth = [item for item in bandwidths if item is not None and item > 0]

-    if specified_throughputs:
-        default_throughput = np.mean(specified_throughputs)
-        throughputs = [throughput if throughput is not None else default_throughput for throughput in throughputs]
-        scores = optimize_parts_lp(vector_size, np.asarray(throughputs), min_size)
+    if specified_bandwidth:
+        default_bandwidth = np.mean(specified_bandwidth)
+        bandwidths = [item if item is not None else default_bandwidth for item in bandwidths]
+        scores = optimize_parts_lp(vector_size, np.asarray(bandwidths), min_size)
     else:
-        assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
-        scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
+        assert not all(item == 0 for item in bandwidths), "Must have at least one nonzero bandwidth"
+        scores = np.asarray([1.0 if item is None else 0.0 for item in bandwidths])

     # TODO(jheuristic) we no longer need hagenbach-bishoff with new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))


-def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int = 0) -> np.ndarray:
+def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int = 0) -> np.ndarray:
     """
     This method solves an optimization problem to minimize the total allreduce time.
     In butterfly all-reduce, each peer acts both as a "client" and as an "aggregator":
@@ -42,20 +42,20 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
     Peer i network load as a "client" = vector_size * (1 - fraction_assigned_to_peer_i)
     Peer i network load as an "aggregator" = vector_size * (group_size - 1) * fraction_assigned_to_peer_i
     Peer i total communication = vector_size * [1 + (group_size - 2) * fraction_assigned_to_peer_i]
-    Total time = max_i (total_communication_for_peer_i / throughputs[i])
+    Total time = max_i (total_communication_for_peer_i / bandwidths[i])

     We solve this optimization problem by reducing it to linear programming with a minimax reduction
     (see lecture notes: https://www.usna.edu/Users/math/dphillip/sa305.s15/phillips/lessons/32/32.pdf )

     :returns: a vector of "scores", i-th score is proportional to the fraction of weights assigned to i-th peer
     """
-    assert np.all(throughputs >= 0) and np.any(throughputs > 0)
-    throughputs = np.asarray(throughputs, dtype=np.float64)
-    permutation = np.argsort(-throughputs)
-    throughputs = throughputs[permutation]
-    is_nonzero = throughputs != 0
+    assert np.all(bandwidths >= 0) and np.any(bandwidths > 0)
+    bandwidths = np.asarray(bandwidths, dtype=np.float64)
+    permutation = np.argsort(-bandwidths)
+    bandwidths = bandwidths[permutation]
+    is_nonzero = bandwidths != 0

-    group_size = len(throughputs)
+    group_size = len(bandwidths)
     num_variables = group_size + 1  # [w_1, ..., w_N, xi]

     c = np.zeros(num_variables, dtype=np.float64)
@@ -64,9 +64,9 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
     # the constraints below are tuples (A, b) such that Ax <= b
     nonnegative_weights = -np.eye(group_size, num_variables, dtype=c.dtype), np.zeros(group_size, c.dtype)
     weights_sum_to_one = c[None, :] - 1.0, np.array([-1.0])
-    coeff_per_variable = (group_size - 2.0) / np.maximum(throughputs, 10 ** -LOAD_BALANCING_LP_DECIMALS)
+    coeff_per_variable = (group_size - 2.0) / np.maximum(bandwidths, 10 ** -LOAD_BALANCING_LP_DECIMALS)
     coeff_matrix_minus_xi = np.hstack([np.diag(coeff_per_variable), -np.ones((group_size, 1), c.dtype)])
-    xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / throughputs[is_nonzero]
+    xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / bandwidths[is_nonzero]
     force_max_weights = np.eye(group_size, M=num_variables, dtype=c.dtype), is_nonzero.astype(c.dtype)

     A, b = list(map(np.concatenate, zip(nonnegative_weights, weights_sum_to_one, xi_is_maximum, force_max_weights)))
@@ -79,7 +79,7 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
             peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
         peer_scores = np.round(peer_scores, LOAD_BALANCING_LP_DECIMALS)
     else:
-        logger.error(f"Failed to solve load-balancing for bandwidths {throughputs}.")
+        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}.")
         peer_scores = np.ones(group_size, c.dtype)

     return peer_scores[np.argsort(permutation)]
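
A small usage sketch of the renamed function, based on the signature and docstring above (numbers are arbitrary):

```python
from hivemind.averaging.load_balancing import load_balance_peers

# two full peers with different bandwidths plus one client-only peer (bandwidth 0)
assignment = load_balance_peers(vector_size=1_000_000, bandwidths=[1e9, 1e8, 0.0])
print(assignment)       # a tuple of ints summing to 1_000_000; the client-only peer gets 0 elements
print(sum(assignment))  # 1000000
```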

+ 1 - 7
hivemind/dht/__init__.py

@@ -26,7 +26,6 @@ from multiaddr import Multiaddr
 from hivemind.dht.node import DHTID, DHTNode
 from hivemind.dht.routing import DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.p2p import P2P
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop

 logger = get_logger(__name__)
@@ -40,9 +39,6 @@ class DHT(mp.Process):
     * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
     * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)

-    :param p2p: instance of hivemind.p2p.P2P that will be used for communication.
-      If None, DHTNode will create and manage its own P2P instance with given initial_peers and
-      parameters from ``kwargs``
     :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
@@ -60,7 +56,6 @@ class DHT(mp.Process):

     def __init__(
         self,
-        p2p: Optional[P2P] = None,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         *,
         start: bool,
@@ -70,9 +65,9 @@ class DHT(mp.Process):
         shutdown_timeout: float = 3,
         **kwargs,
     ):
+        self._parent_pid = os.getpid()
         super().__init__()

-        self.p2p = p2p
         if not (
             initial_peers is None
             or (
@@ -101,7 +96,6 @@

             async def _run():
                 self._node = await DHTNode.create(
-                    p2p=self.p2p,
                     initial_peers=self.initial_peers,
                     num_workers=self.max_workers or 1,
                     record_validator=self._record_validator,
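
The _parent_pid mentioned in the commit message guards process-level cleanup so that only the process that created the object runs it; a generic sketch of the pattern (not the exact hivemind code):

```python
import os

class BackgroundWorker:
    def __init__(self):
        self._parent_pid = os.getpid()   # remember the creating process

    def shutdown(self):
        print("shutting down")

    def __del__(self):
        # forked children and other edge cases must not trigger the parent's shutdown
        if os.getpid() == self._parent_pid:
            self.shutdown()
```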

+ 4 - 4
hivemind/dht/node.py

@@ -114,7 +114,7 @@ class DHTNode:
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
-        listen: bool = True,
+        client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
         validate: bool = True,
@@ -154,8 +154,8 @@ class DHTNode:
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
         :param validate: if True, use initial peers to validate that this node is accessible and synchronized
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
-          if False, this node will refuse any incoming request, effectively being only a "client"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
+          if True, this node will refuse any incoming requests, effectively being only a client
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
           for a given authorization protocol
@@ -203,7 +203,7 @@ class DHTNode:
             wait_timeout,
             parallel_rpc,
             cache_size,
-            listen,
+            client_mode,
             record_validator,
             authorizer,
         )

+ 6 - 6
hivemind/dht/protocol.py

@@ -43,7 +43,7 @@ class DHTProtocol(ServicerBase):
         wait_timeout: float,
         parallel_rpc: Optional[int] = None,
         cache_size: Optional[int] = None,
-        listen=True,
+        client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
     ) -> DHTProtocol:
@@ -66,15 +66,15 @@ class DHTProtocol(ServicerBase):
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float("inf"))
-        self.listen = listen
+        self.client_mode = client_mode
         self.record_validator = record_validator
         self.authorizer = authorizer

-        if listen:
+        if not client_mode:
             await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))

             self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
-        else:  # client-only mode
+        else:
             # note: use empty node_info so peers won't add you to their routing tables
             self.node_info = dht_pb2.NodeInfo()
         return self
@@ -95,7 +95,7 @@ class DHTProtocol(ServicerBase):
         :param peer: peer ID to ping
         :param validate: if True, validates that node's peer_id is available
         :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
-        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
+        :note: if DHTProtocol was created with client_mode=False, also request peer to add you to his routing table

         :return: node's DHTID, if peer responded and decided to send his node_id
         """
@@ -112,7 +112,7 @@ class DHTProtocol(ServicerBase):

         if responded and validate:
             try:
-                if self.listen and not response.available:
+                if not self.client_mode and not response.available:
                     raise ValidationError(
                         f"Peer {peer} can't access this node. " f"Probably, libp2p has failed to bypass the firewall"
                     )

+ 1 - 1
hivemind/optim/adaptive.py

@@ -29,6 +29,6 @@ class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
             average_opt_statistics=average_opt_statistics,
             prefix=f"{self.prefix}_averaging",
             allreduce_timeout=self.averaging_timeout,
-            listen=not self.client_mode,
+            client_mode=self.client_mode,
             **kwargs,
         )

+ 2 - 2
hivemind/optim/collaborative.py

@@ -167,7 +167,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             average_gradients=True,
             prefix=f"{self.prefix}_averaging",
             allreduce_timeout=self.averaging_timeout,
-            listen=not self.client_mode,
+            client_mode=self.client_mode,
             **kwargs,
         )

@@ -359,7 +359,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,
                     time=current_time,
-                    client_mode=not self.averager.listen,
+                    client_mode=self.averager.client_mode,
                 )

             self.dht.store(

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -70,7 +70,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively

-    SOFT_UPDATE_TIMEOUT = 0.1  # seconds spent awaiting status update before warning is printed
+    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
     HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled

     def __init__(self, *, synchronize: bool = True, use_lock: bool = True):

+ 13 - 13
tests/test_averaging.py

@@ -76,7 +76,7 @@ def _test_allreduce_once(n_clients, n_aux):
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
-            listen=mode != AveragingMode.CLIENT,
+            client_mode=mode == AveragingMode.CLIENT,
             listen_on="127.0.0.1:*",
             auxiliary=mode == AveragingMode.AUX,
             start=True,
@@ -121,8 +121,8 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True)

     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
+    client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
+    random.shuffle(client_modes)

     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
@@ -135,11 +135,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
-            listen=listen,
+            client_mode=client_mode,
             listen_on="127.0.0.1:*",
             start=True,
         )
-        for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)
+        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
     ]
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [
@@ -180,7 +180,7 @@ def test_allreduce_compression():
             [x.clone() for x in tensors1],
             dht=dht,
             compression_type=compression_type_pair,
-            listen=False,
+            client_mode=True,
             target_group_size=2,
             prefix="mygroup",
             start=True,
@@ -306,16 +306,16 @@ def test_allgather():
     dht.shutdown()


-def get_cost(vector_size, partitions, throughputs):
+def get_cost(vector_size, partitions, bandwidths):
     return max(
-        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
+        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(bandwidths[i], 1e-9)
         for i in range(len(partitions))
     )


-def check_optimality(vector_size, throughputs, ref_partitions):
-    partitions = list(load_balance_peers(vector_size, throughputs))
-    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)
+def check_optimality(vector_size, bandwidths, ref_partitions):
+    partitions = list(load_balance_peers(vector_size, bandwidths))
+    assert get_cost(vector_size, partitions, bandwidths) <= get_cost(vector_size, ref_partitions, bandwidths)


 @pytest.mark.forked
@@ -342,9 +342,9 @@ def test_load_balancing():
         vector_size = np.random.randint(1, 1024 ** 3)
         num_peers = np.random.randint(1, 256)
         scale = 1e-9 + np.random.rand() * 1e5
-        throughputs = np.random.rand(num_peers) * scale + 1e-6
+        bandwidths = np.random.rand(num_peers) * scale + 1e-6
         min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
-        assignment = load_balance_peers(vector_size, throughputs, min_size)
+        assignment = load_balance_peers(vector_size, bandwidths, min_size)
         assert np.sum(assignment) == vector_size
         assert np.min(assignment) >= 0
 
 

+ 1 - 1
tests/test_dht.py

@@ -102,7 +102,7 @@ async def test_dht_get_visible_maddrs():

     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(p2p, start=True)
+    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))

     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
     dht.shutdown()

+ 6 - 5
tests/test_dht_node.py

@@ -77,11 +77,12 @@ def test_dht_protocol():
     peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)

     loop = asyncio.get_event_loop()
-    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
+        peer_id = DHTID.generate()
         p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
         protocol = loop.run_until_complete(
             DHTProtocol.create(
-                p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen
+                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
             )
         )
         logger.info(f"Self id={protocol.node_id}")
@@ -150,7 +151,7 @@ def test_dht_protocol():
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

-        if listen:
+        if not client_mode:
             loop.run_until_complete(p2p.shutdown())

     peer1_proc.terminate()
@@ -166,7 +167,7 @@ def test_empty_table():
     p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
     protocol = loop.run_until_complete(
         DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False
+            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
         )
     )

@@ -353,7 +354,7 @@ async def test_dhtnode_caching(T=0.05):
     node1 = await DHTNode.create(
         initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
         cache_refresh_before_expiry=5 * T,
-        listen=False,
+        client_mode=True,
         reuse_get_requests=False,
     )
     await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)