
Merge remote-tracking branch 'origin/master' into Mike-2

Michael Diskin, 3 years ago
commit 482ee60963

+ 14 - 14
examples/albert/README.md

@@ -20,15 +20,16 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics. If you're unfamiliar with Weights
   & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-- Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
-- `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
-  due to naming conventions.
-- `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the
-  same project name.
+- Run `python run_training_monitor.py --experiment_prefix YOUR_EXPERIMENT_NAME --wandb_project YOUR_WANDB_PROJECT`
+
+  - `YOUR_EXPERIMENT_NAME` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
+    due to naming conventions.
+  - `YOUR_WANDB_PROJECT` is a name of wandb project used to track training metrics. Multiple experiments can have the
+    same project name.
 
 ```
 $ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet, 
+[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet,
 use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
@@ -56,8 +57,8 @@ To join the collaboration with a GPU trainer,
 - Run:
   ```bash
   python run_trainer.py \
-  --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
-  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+      --experiment_prefix YOUR_EXPERIMENT_NAME --initial_peers ONE_OR_MORE_PEERS \
+      --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
   ```
 
   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
@@ -135,7 +136,7 @@ incoming connections (e.g. when in colab or behind a firewall), add `--client_mo
 below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
-recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. 
+recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
 
 The example trainer supports
 multiple GPUs via DataParallel. However, using advanced distributed training strategies (
@@ -155,7 +156,7 @@ collaborative experiment. Here's how to best use them:
 - Most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example
   below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop
   with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we
-  tested this code on preemptible 
+  tested this code on preemptible
   [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/)
   nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
 - You can create starter notebooks to make it more convenient for collaborators to join your training
@@ -169,10 +170,9 @@ Here's an example of a full trainer script for Google Colab:
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -
 !ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
- --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
- --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
- --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
- --experiment_prefix EXPERIMENT_NAME_HERE --seed 42
+    --experiment_prefix YOUR_EXPERIMENT_NAME --initial_peers ONE_OR_MORE_PEERS \
+    --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
+    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```
 
 ### Using IPFS

+ 1 - 1
examples/albert/arguments.py

@@ -49,7 +49,7 @@ class AveragerArguments:
         default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}

+ 20 - 16
hivemind/averaging/allreduce.py

@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
@@ -56,9 +55,9 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
@@ -73,23 +72,24 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()
 
-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)
 
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
@@ -97,7 +97,6 @@ class AllReduceRunner(ServicerBase):
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )
 
     def __repr__(self):
@@ -149,7 +148,9 @@ class AllReduceRunner(ServicerBase):
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
@@ -180,9 +181,10 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
 
     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
@@ -219,14 +221,16 @@ class AllReduceRunner(ServicerBase):
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
+        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
             amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             )
         ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+            averaged_part = await self.tensor_part_reducer.accumulate_part(
+                sender_index, part_index, tensor_part, weight=weight
+            )
 
             serialized_delta = await loop.run_in_executor(
                 None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)

+ 5 - 4
hivemind/averaging/averager.py

@@ -367,7 +367,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
+                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
@@ -376,7 +376,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                            ),
                             timeout=self._allreduce_timeout,
                         )
                     )
@@ -414,7 +416,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
@@ -435,7 +437,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    weights=weights,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,

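For orientation, here is a minimal sketch of the gathered-metadata layout after this change: each peer now packs `[bandwidth, mode, gather_binary]` into `data_for_gather`, and the averaging weight travels with the all-reduce parts instead. This assumes the averager's serializer is the default `MSGPackSerializer`; all values below are made up for illustration.

```python
from hivemind.utils import MSGPackSerializer

# Hypothetical per-peer metadata: [bandwidth, mode_id, user-defined gather blob].
# Note that the averaging weight is no longer part of this tuple.
gathered = [
    MSGPackSerializer.dumps([100.0, 0, MSGPackSerializer.dumps("peer A metadata")]),
    MSGPackSerializer.dumps([50.0, 1, MSGPackSerializer.dumps("peer B metadata")]),
]

bandwidths, mode_ids, user_gathered = zip(*map(MSGPackSerializer.loads, gathered))
print(bandwidths, mode_ids)  # (100.0, 50.0) (0, 1)
```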
+ 6 - 8
hivemind/averaging/partition.py

@@ -167,15 +167,11 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator = None  # this will contain the sum of current tensor part from group peers
@@ -197,7 +193,9 @@ class TensorPartReducer:
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
@@ -211,9 +209,9 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.add_(tensor_part, alpha=weight)
         self.current_part_accumulated_from += 1
+        self.denominator += weight
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:

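To make the new reduction path concrete, here is a minimal, self-contained sketch of the arithmetic that `accumulate_part` performs with a per-call `weight`: each part is added with `alpha=weight` and the running sum is divided by the total weight, so weights only need to be proportional rather than normalized. This is toy code under those assumptions, not the actual asynchronous `TensorPartReducer`.

```python
import torch


def weighted_average(parts, weights):
    # accumulator += weight * part; denominator += weight; result = accumulator / denominator
    accumulator = torch.zeros_like(parts[0])
    denominator = 0.0
    for part, weight in zip(parts, weights):
        accumulator.add_(part, alpha=weight)
        denominator += weight
    return accumulator / denominator


# Example: the peer with weight 3.0 pulls the average toward its part.
parts = [torch.ones(4), torch.full((4,), 2.0)]
print(weighted_average(parts, [1.0, 3.0]))  # tensor([1.7500, 1.7500, 1.7500, 1.7500])
```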
+ 2 - 3
hivemind/moe/client/expert.py

@@ -1,4 +1,3 @@
-import pickle
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -7,7 +6,7 @@ from torch.autograd.function import once_differentiable
 
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -60,7 +59,7 @@ class RemoteExpert(nn.Module):
     def info(self):
         if self._info is None:
             outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = pickle.loads(outputs.serialized_info)
+            self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):

+ 2 - 3
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,5 @@
 import multiprocessing as mp
 import os
-import pickle
 from typing import Dict
 
 import grpc
@@ -9,7 +8,7 @@ import torch
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, get_logger, nested_flatten
+from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
@@ -61,7 +60,7 @@ class ConnectionHandler(mp.context.ForkProcess):
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
-        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

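Both `RemoteExpert.info` and `ConnectionHandler.info` now round-trip the expert metadata through `MSGPackSerializer` instead of `pickle`, so a response can no longer trigger arbitrary code execution on deserialization. A hedged sketch of that round trip (the dict fields are hypothetical; the real payload comes from `ExpertBackend.get_info()`):

```python
from hivemind.utils import MSGPackSerializer

info = {"forward_schema": [None, 1024], "outputs_schema": [None, 1024]}  # made-up fields
payload = MSGPackSerializer.dumps(info)          # bytes, suitable for runtime_pb2.ExpertInfo.serialized_info
assert MSGPackSerializer.loads(payload) == info  # plain data only, no arbitrary objects reconstructed
```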
+ 39 - 19
hivemind/optim/collaborative.py

@@ -149,7 +149,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -181,6 +181,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    @property
+    def is_within_tolerance(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
@@ -197,7 +201,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
 
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
@@ -226,10 +230,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
+        if not self.is_synchronized and not self.is_within_tolerance:
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -241,7 +248,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
+            self.local_updates_accumulated += 1
             self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
@@ -249,25 +256,31 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
                 try:
-                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    group_info = self.averager.step(
+                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
+                    )
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -279,7 +292,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
@@ -304,12 +317,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -412,9 +432,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
-            )
+            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
+            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
+
             return CollaborationState(
                 self.local_step,
                 self.local_samples_accumulated,

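Two of the changes above are easier to see with numbers: the averaging weight is now proportional to how many samples this peer contributed relative to the collaboration-wide mean, and the steps gathered from groupmates are used to fast-forward `current_step`. A toy illustration with made-up values:

```python
# Made-up values; in CollaborativeOptimizer these come from the live training run.
target_batch_size = 4096
num_peers = 8
local_samples_accumulated = 768

mean_samples_per_worker = target_batch_size / num_peers       # 512.0
weight = local_samples_accumulated / mean_samples_per_worker  # 1.5: above-average contribution

# Fast-forward the local step using the steps gathered from groupmates, ignoring malformed entries.
current_step = 40
gathered_steps = {"peer_a": 41, "peer_b": 40, "peer_c": None}  # hypothetical gather results
for peer, peer_step in gathered_steps.items():
    if isinstance(peer_step, int):
        current_step = max(current_step, peer_step)
print(weight, current_step)  # 1.5 41
```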
+ 1 - 1
hivemind/proto/averaging.proto

@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;       // reserved user-extendable metadata
+  double weight = 5;        // peers will be averaged in proportion to these weights
 }
 
 message DownloadRequest {}

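With field 5 repurposed, every streamed `AveragingData` message now carries its sender's weight as a double next to the tensor part. A minimal construction for illustration (the tensor part is omitted and the value is arbitrary):

```python
from hivemind.proto import averaging_pb2

msg = averaging_pb2.AveragingData(weight=2.5)  # field 5 used to be opaque metadata bytes
print(msg.weight)  # 2.5
```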
+ 50 - 1
hivemind/utils/tensor_descr.py

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.serializer import MSGPackSerializer
 
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
@@ -52,6 +53,18 @@ class TensorDescriptor(DescriptorBase):
         return torch.empty(**properties)
 
 
+def _str_to_torch_type(name: str, torch_type: type):
+    try:
+        value = getattr(torch, name.split(".")[-1])
+    except AttributeError:
+        raise ValueError(f"Invalid dtype: torch has no attribute {name}")
+    if not isinstance(value, torch_type):
+        raise ValueError(f"Invalid dtype: expected {torch_type}, got: {type(value)}")
+
+    return value
+
+
+@MSGPackSerializer.ext_serializable(0x51)
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
     """torch.Tensor with a variable 0-th dimension, used to describe batched data"""
@@ -70,13 +83,49 @@ class BatchTensorDescriptor(TensorDescriptor):
             device=tensor.device,
             requires_grad=tensor.requires_grad,
             pin_memory=_safe_check_pinned(tensor),
-            compression=compression if tensor.is_floating_point() else CompressionType.NONE
+            compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
     def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 
+    def packb(self) -> bytes:
+        obj_dict = asdict(self)
+
+        obj_dict["dtype"] = str(self.dtype) if self.dtype is not None else None
+        obj_dict["layout"] = str(self.layout) if self.layout is not None else None
+
+        device = obj_dict.pop("device")
+        device_type, device_index = (device.type, device.index) if device is not None else (None, None)
+        obj_dict.update(
+            device_type=device_type,
+            device_index=device_index,
+        )
+
+        return MSGPackSerializer.dumps(obj_dict)
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> BatchTensorDescriptor:
+        obj_dict = MSGPackSerializer.loads(raw)
+
+        if obj_dict["dtype"] is not None:
+            obj_dict["dtype"] = _str_to_torch_type(obj_dict["dtype"], torch.dtype)
+
+        if obj_dict["layout"] is not None:
+            obj_dict["layout"] = _str_to_torch_type(obj_dict["layout"], torch.layout)
+
+        if obj_dict["device_type"] is not None:
+            obj_dict["device"] = torch.device(obj_dict["device_type"], obj_dict["device_index"])
+        else:
+            obj_dict["device"] = None
+
+        del obj_dict["device_type"], obj_dict["device_index"]
+
+        size = obj_dict.pop("size")[1:]
+
+        return BatchTensorDescriptor(*size, **obj_dict)
+
 
 def _safe_check_pinned(tensor: torch.Tensor) -> bool:
     """Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error."""

+ 2 - 2
tests/test_allreduce.py

@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)

+ 16 - 1
tests/test_util_modules.py

@@ -13,7 +13,7 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
@@ -521,3 +521,18 @@ async def test_cancel_and_wait():
     await asyncio.sleep(0.05)
     assert not await cancel_and_wait(task_with_result)
     assert not await cancel_and_wait(task_with_error)
+
+
+def test_batch_tensor_descriptor_msgpack():
+    tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
+    tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))
+
+    assert (
+        tensor_descr.size == tensor_descr_roundtrip.size
+        and tensor_descr.dtype == tensor_descr_roundtrip.dtype
+        and tensor_descr.layout == tensor_descr_roundtrip.layout
+        and tensor_descr.device == tensor_descr_roundtrip.device
+        and tensor_descr.requires_grad == tensor_descr_roundtrip.requires_grad
+        and tensor_descr.pin_memory == tensor_descr.pin_memory
+        and tensor_descr.compression == tensor_descr.compression
+    )