@@ -37,13 +37,12 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
@@ -56,9 +55,9 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
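
For callers, the practical effect of this signature change is that the per-peer weights sequence disappears and each peer passes only its own scalar. A purely illustrative sketch of the before/after keyword arguments (the names below are hypothetical, and the real constructor also takes arguments outside this diff, such as the P2P/servicer handles):

# Hypothetical caller-side keyword arguments, for illustration only.
old_style_kwargs = dict(weights=[1.0, 2.0, 0.0])  # one entry per peer in ordered_peer_ids, aux peers pinned to 0
new_style_kwargs = dict(weight=2.0)               # only this peer's own scalar weight (or None for the default)
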
@@ -73,23 +72,24 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()
 
-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)
 
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
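
The default above means a peer that passes weight=None contributes 1.0 unless it is an auxiliary peer, in which case it contributes 0.0. A self-contained sketch of just that rule (using a stand-in enum rather than importing hivemind):

from enum import Enum

class AveragingMode(Enum):  # stand-in for hivemind's AveragingMode
    NODE = 0
    CLIENT = 1
    AUX = 2

def default_weight(local_mode: AveragingMode) -> float:
    # Mirrors `float(modes[...] != AveragingMode.AUX)` from the constructor above.
    return float(local_mode != AveragingMode.AUX)

assert default_weight(AveragingMode.NODE) == 1.0
assert default_weight(AveragingMode.CLIENT) == 1.0
assert default_weight(AveragingMode.AUX) == 0.0
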
@@ -97,7 +97,6 @@ class AllReduceRunner(ServicerBase):
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )
 
     def __repr__(self):
@@ -149,7 +148,9 @@ class AllReduceRunner(ServicerBase):
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
@@ -180,9 +181,10 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
 
     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
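
On the sending side, the peer's scalar weight is now attached to every streamed AveragingData message rather than being agreed on out of band for the whole group. A minimal stand-alone sketch of that pattern, where Message is only a stand-in for averaging_pb2.AveragingData (the real message also carries fields such as code and group_id):

from dataclasses import dataclass
from typing import Iterator, List

@dataclass
class Message:  # stand-in for averaging_pb2.AveragingData
    tensor_part: bytes
    weight: float

def generate_parts(parts: List[bytes], weight: float) -> Iterator[Message]:
    # Every chunk of this peer's tensors carries the same per-peer weight,
    # so the receiver can reduce each part without a separate weight list.
    for part in parts:
        yield Message(tensor_part=part, weight=weight)

assert all(msg.weight == 0.5 for msg in generate_parts([b"part0", b"part1"], weight=0.5))
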
@@ -219,14 +221,16 @@ class AllReduceRunner(ServicerBase):
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
+        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
             amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             )
         ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+            averaged_part = await self.tensor_part_reducer.accumulate_part(
+                sender_index, part_index, tensor_part, weight=weight
+            )
 
             serialized_delta = await loop.run_in_executor(
                 None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)