
Implement weights as part of the allreduce protocol, not matchmaking (#384)

* implement weights as part of the allreduce protocol, not matchmaking
* remove metadata field from AveragingData (unused)

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic · 3 years ago · commit 4a9bc92cd1
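In short: instead of gathering a per-group `weights` list during matchmaking, each peer now attaches its own scalar `weight` to every tensor part it sends, and reducers normalize by the sum of received weights. A toy sketch of these semantics in plain PyTorch (not hivemind code):

    import torch

    # Weighted average as computed by the reducers in the diff below:
    # weights don't need to sum to 1, the denominator normalizes them.
    parts = [torch.ones(4), torch.full((4,), 3.0)]
    weights = [1.0, 2.0]  # e.g. proportional to each peer's batch size
    average = sum(w * x for w, x in zip(weights, parts)) / sum(weights)
    assert torch.allclose(average, torch.full((4,), 7.0 / 3.0))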

+ 20 - 16
hivemind/averaging/allreduce.py

@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
@@ -56,9 +55,9 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
@@ -73,23 +72,24 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix

         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()

-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)

         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
@@ -97,7 +97,6 @@ class AllReduceRunner(ServicerBase):
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )

     def __repr__(self):
@@ -149,7 +148,9 @@ class AllReduceRunner(ServicerBase):
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

         else:
@@ -180,9 +181,10 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)

     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
@@ -219,14 +221,16 @@

     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
+        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
             amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             )
         ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+            averaged_part = await self.tensor_part_reducer.accumulate_part(
+                sender_index, part_index, tensor_part, weight=weight
+            )

             serialized_delta = await loop.run_in_executor(
                 None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
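With this change, the weight rides along on the wire: the first AveragingData message of a stream carries code, group_id, tensor_part and weight, and every later message carries just tensor_part and weight. A hedged sender-side sketch using the names from the diff above (make_messages is a hypothetical helper; serialized_parts are Tensor protos as produced by serialize_torch_tensor):

    from hivemind.proto import averaging_pb2

    def make_messages(serialized_parts, group_id, weight):
        # The first message opens the stream with routing info; later ones
        # carry only the payload. All are stamped with the sender's weight.
        for i, part in enumerate(serialized_parts):
            if i == 0:
                yield averaging_pb2.AveragingData(
                    code=averaging_pb2.PART_FOR_AVERAGING,
                    group_id=group_id,
                    tensor_part=part,
                    weight=weight,
                )
            else:
                yield averaging_pb2.AveragingData(tensor_part=part, weight=weight)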

+ 5 - 4
hivemind/averaging/averager.py

@@ -367,7 +367,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
+                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
@@ -376,7 +376,9 @@

                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                            ),
                             timeout=self._allreduce_timeout,
                         )
                     )
@@ -414,7 +416,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))

@@ -435,7 +437,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    weights=weights,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,
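On the averager side, the peer's weight no longer travels through data_for_gather; it is passed straight into _run_allreduce and on to AllReduceRunner. Assuming weight originates from the step() call (as the surrounding code suggests; the exact step() signature is not shown in this diff), usage might look like:

    # Hypothetical usage: scale this peer's contribution by its local batch size.
    averager.step(weight=float(local_batch_size))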

+ 6 - 8
hivemind/averaging/partition.py

@@ -167,15 +167,11 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """

-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator = None  # this will contain the sum of current tensor part from group peers
@@ -197,7 +193,9 @@ class TensorPartReducer:
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0

-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
@@ -211,9 +209,9 @@

         current_part_future = self.current_part_future

-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.add_(tensor_part, alpha=weight)
         self.current_part_accumulated_from += 1
+        self.denominator += weight

         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
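The reducer update above is an incremental weighted mean: each received part is added with alpha=weight while the denominator tracks the total weight, so once all senders report, the average is sum(w_i * x_i) / sum(w_i). A self-contained sketch of that arithmetic:

    import torch

    accumulator, denominator = torch.zeros(3), 0.0
    for tensor_part, weight in [(torch.ones(3), 2.0), (torch.full((3,), 4.0), 1.0)]:
        accumulator.add_(tensor_part, alpha=weight)  # accumulator += weight * part
        denominator += weight
    average = accumulator / denominator  # (2*1 + 1*4) / 3 == 2.0 elementwise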

+ 1 - 1
hivemind/proto/averaging.proto

@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;       // reserved user-extendable metadata
+  double weight = 5;        // peers will be averaged in proportion to these weights
 }

 message DownloadRequest {}
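Note that weight reuses field number 5 of the removed metadata field and, being a proto3 scalar, reads back as 0.0 when unset; the senders in this commit always populate it. A quick round-trip check (assuming the generated averaging_pb2 module, as imported in the diff above):

    from hivemind.proto import averaging_pb2

    msg = averaging_pb2.AveragingData(weight=0.5)
    assert msg.weight == 0.5
    assert averaging_pb2.AveragingData().weight == 0.0  # proto3 default for double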

+ 2 - 2
tests/test_allreduce.py

@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")

     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)