Переглянути джерело

Implement weights as part of the allreduce protocol, not matchmaking (#384)

* implement parts as part of the allreduce protocol, not matchmaking
* remove metadata field from AveragingData (unused)

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 3 роки тому
батько
коміт
4a9bc92cd1

+ 20 - 16
hivemind/averaging/allreduce.py

@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
@@ -56,9 +55,9 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
@@ -73,23 +72,24 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()
 
-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)
 
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
@@ -97,7 +97,6 @@ class AllReduceRunner(ServicerBase):
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )
 
     def __repr__(self):
@@ -149,7 +148,9 @@ class AllReduceRunner(ServicerBase):
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
@@ -180,9 +181,10 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
 
     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
@@ -219,14 +221,16 @@ class AllReduceRunner(ServicerBase):
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
+        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
             amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             )
         ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+            averaged_part = await self.tensor_part_reducer.accumulate_part(
+                sender_index, part_index, tensor_part, weight=weight
+            )
 
             serialized_delta = await loop.run_in_executor(
                 None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)

+ 5 - 4
hivemind/averaging/averager.py

@@ -367,7 +367,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
+                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
@@ -376,7 +376,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                            ),
                             timeout=self._allreduce_timeout,
                         )
                     )
@@ -414,7 +416,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
@@ -435,7 +437,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    weights=weights,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,

+ 6 - 8
hivemind/averaging/partition.py

@@ -167,15 +167,11 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator = None  # this will contain the sum of current tensor part from group peers
@@ -197,7 +193,9 @@ class TensorPartReducer:
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
@@ -211,9 +209,9 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.add_(tensor_part, alpha=weight)
         self.current_part_accumulated_from += 1
+        self.denominator += weight
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:

+ 1 - 1
hivemind/proto/averaging.proto

@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;       // reserved user-extendable metadata
+  double weight = 5;        // peers will be averaged in proportion to these weights
 }
 
 message DownloadRequest {}

+ 2 - 2
tests/test_allreduce.py

@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)