Ver Fonte

Merge remote-tracking branch 'origin/master' into Mike-2

Michael Diskin há 3 anos atrás
pai
commit
482ee60963

+ 14 - 14
examples/albert/README.md

@@ -20,15 +20,16 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics. If you're unfamiliar with Weights
   & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-- Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
-- `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
-  due to naming conventions.
-- `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the
-  same project name.
+- Run `python run_training_monitor.py --experiment_prefix YOUR_EXPERIMENT_NAME --wandb_project YOUR_WANDB_PROJECT`
+
+  - `YOUR_EXPERIMENT_NAME` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
+    due to naming conventions.
+  - `YOUR_WANDB_PROJECT` is a name of wandb project used to track training metrics. Multiple experiments can have the
+    same project name.
 
 ```
 $ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet, 
+[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet,
 use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
@@ -56,8 +57,8 @@ To join the collaboration with a GPU trainer,
 - Run:
   ```bash
   python run_trainer.py \
-  --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
-  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+      --experiment_prefix YOUR_EXPERIMENT_NAME --initial_peers ONE_OR_MORE_PEERS \
+      --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
   ```
 
   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
@@ -135,7 +136,7 @@ incoming connections (e.g. when in colab or behind a firewall), add `--client_mo
 below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
-recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. 
+recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
 
 The example trainer supports
 multiple GPUs via DataParallel. However, using advanced distributed training strategies (
@@ -155,7 +156,7 @@ collaborative experiment. Here's how to best use them:
 - Most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example
   below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop
   with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we
-  tested this code on preemptible 
+  tested this code on preemptible
   [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/)
   nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
 - You can create starter notebooks to make it more convenient for collaborators to join your training
@@ -169,10 +170,9 @@ Here's an example of a full trainer script for Google Colab:
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -
 !ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
- --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
- --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
- --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
- --experiment_prefix EXPERIMENT_NAME_HERE --seed 42
+    --experiment_prefix YOUR_EXPERIMENT_NAME --initial_peers ONE_OR_MORE_PEERS \
+    --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
+    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```
 
 ### Using IPFS

+ 1 - 1
examples/albert/arguments.py

@@ -49,7 +49,7 @@ class AveragerArguments:
         default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}

+ 20 - 16
hivemind/averaging/allreduce.py

@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
@@ -56,9 +55,9 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
@@ -73,23 +72,24 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()
 
-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)
 
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
@@ -97,7 +97,6 @@ class AllReduceRunner(ServicerBase):
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )
 
     def __repr__(self):
@@ -149,7 +148,9 @@ class AllReduceRunner(ServicerBase):
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
@@ -180,9 +181,10 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
 
     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
@@ -219,14 +221,16 @@ class AllReduceRunner(ServicerBase):
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
+        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
             amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             )
         ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+            averaged_part = await self.tensor_part_reducer.accumulate_part(
+                sender_index, part_index, tensor_part, weight=weight
+            )
 
             serialized_delta = await loop.run_in_executor(
                 None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)

+ 5 - 4
hivemind/averaging/averager.py

@@ -367,7 +367,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
+                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
@@ -376,7 +376,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                            ),
                             timeout=self._allreduce_timeout,
                         )
                     )
@@ -414,7 +416,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
@@ -435,7 +437,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    weights=weights,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,

+ 6 - 8
hivemind/averaging/partition.py

@@ -167,15 +167,11 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator = None  # this will contain the sum of current tensor part from group peers
@@ -197,7 +193,9 @@ class TensorPartReducer:
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
@@ -211,9 +209,9 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.add_(tensor_part, alpha=weight)
         self.current_part_accumulated_from += 1
+        self.denominator += weight
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:

+ 2 - 3
hivemind/moe/client/expert.py

@@ -1,4 +1,3 @@
-import pickle
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -7,7 +6,7 @@ from torch.autograd.function import once_differentiable
 
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -60,7 +59,7 @@ class RemoteExpert(nn.Module):
     def info(self):
         if self._info is None:
             outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = pickle.loads(outputs.serialized_info)
+            self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):

+ 2 - 3
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,5 @@
 import multiprocessing as mp
 import os
-import pickle
 from typing import Dict
 
 import grpc
@@ -9,7 +8,7 @@ import torch
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, get_logger, nested_flatten
+from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
@@ -61,7 +60,7 @@ class ConnectionHandler(mp.context.ForkProcess):
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
-        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

+ 39 - 19
hivemind/optim/collaborative.py

@@ -149,7 +149,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -181,6 +181,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    @property
+    def is_within_tolerance(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
@@ -197,7 +201,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
 
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
@@ -226,10 +230,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
+        if not self.is_synchronized and not self.is_within_tolerance:
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -241,7 +248,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
+            self.local_updates_accumulated += 1
             self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
@@ -249,25 +256,31 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
                 try:
-                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    group_info = self.averager.step(
+                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
+                    )
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -279,7 +292,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
@@ -304,12 +317,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -412,9 +432,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
-            )
+            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
+            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
+
             return CollaborationState(
                 self.local_step,
                 self.local_samples_accumulated,

+ 1 - 1
hivemind/proto/averaging.proto

@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;       // reserved user-extendable metadata
+  double weight = 5;        // peers will be averaged in proportion to these weights
 }
 
 message DownloadRequest {}

+ 50 - 1
hivemind/utils/tensor_descr.py

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.serializer import MSGPackSerializer
 
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
@@ -52,6 +53,18 @@ class TensorDescriptor(DescriptorBase):
         return torch.empty(**properties)
 
 
+def _str_to_torch_type(name: str, torch_type: type):
+    try:
+        value = getattr(torch, name.split(".")[-1])
+    except AttributeError:
+        raise ValueError(f"Invalid dtype: torch has no attribute {name}")
+    if not isinstance(value, torch_type):
+        raise ValueError(f"Invalid dtype: expected {torch_type}, got: {type(value)}")
+
+    return value
+
+
+@MSGPackSerializer.ext_serializable(0x51)
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
     """torch.Tensor with a variable 0-th dimension, used to describe batched data"""
@@ -70,13 +83,49 @@ class BatchTensorDescriptor(TensorDescriptor):
             device=tensor.device,
             requires_grad=tensor.requires_grad,
             pin_memory=_safe_check_pinned(tensor),
-            compression=compression if tensor.is_floating_point() else CompressionType.NONE
+            compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
     def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 
+    def packb(self) -> bytes:
+        obj_dict = asdict(self)
+
+        obj_dict["dtype"] = str(self.dtype) if self.dtype is not None else None
+        obj_dict["layout"] = str(self.layout) if self.layout is not None else None
+
+        device = obj_dict.pop("device")
+        device_type, device_index = (device.type, device.index) if device is not None else (None, None)
+        obj_dict.update(
+            device_type=device_type,
+            device_index=device_index,
+        )
+
+        return MSGPackSerializer.dumps(obj_dict)
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> BatchTensorDescriptor:
+        obj_dict = MSGPackSerializer.loads(raw)
+
+        if obj_dict["dtype"] is not None:
+            obj_dict["dtype"] = _str_to_torch_type(obj_dict["dtype"], torch.dtype)
+
+        if obj_dict["layout"] is not None:
+            obj_dict["layout"] = _str_to_torch_type(obj_dict["layout"], torch.layout)
+
+        if obj_dict["device_type"] is not None:
+            obj_dict["device"] = torch.device(obj_dict["device_type"], obj_dict["device_index"])
+        else:
+            obj_dict["device"] = None
+
+        del obj_dict["device_type"], obj_dict["device_index"]
+
+        size = obj_dict.pop("size")[1:]
+
+        return BatchTensorDescriptor(*size, **obj_dict)
+
 
 def _safe_check_pinned(tensor: torch.Tensor) -> bool:
     """Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error."""

+ 2 - 2
tests/test_allreduce.py

@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)

+ 16 - 1
tests/test_util_modules.py

@@ -13,7 +13,7 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
@@ -521,3 +521,18 @@ async def test_cancel_and_wait():
     await asyncio.sleep(0.05)
     assert not await cancel_and_wait(task_with_result)
     assert not await cancel_and_wait(task_with_error)
+
+
+def test_batch_tensor_descriptor_msgpack():
+    tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
+    tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))
+
+    assert (
+        tensor_descr.size == tensor_descr_roundtrip.size
+        and tensor_descr.dtype == tensor_descr_roundtrip.dtype
+        and tensor_descr.layout == tensor_descr_roundtrip.layout
+        and tensor_descr.device == tensor_descr_roundtrip.device
+        and tensor_descr.requires_grad == tensor_descr_roundtrip.requires_grad
+        and tensor_descr.pin_memory == tensor_descr.pin_memory
+        and tensor_descr.compression == tensor_descr.compression
+    )