@@ -9,7 +9,15 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    amap_in_executor,
+    anext,
+    as_aiter,
+    attach_event_on_finished,
+)
 
 # flavour types
 GroupID = bytes
@@ -44,7 +52,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
-    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
+      non-averaging peers receive results only after they finish sending, which helps them avoid
+      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
     """
 
     def __init__(
@@ -115,6 +126,9 @@ class AllReduceRunner(ServicerBase):
     def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
+    def should_delay_results(self, peer_id: PeerID) -> bool:
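+        # peers with a zero fraction do not aggregate any parts themselves; per the class :note:,
+        # they should receive averaged results only after they finish sending their own parts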
+        return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0
+
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
@@ -155,7 +169,7 @@ class AllReduceRunner(ServicerBase):
 
         else:
             code = None
-            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
             async for part_index, (averaged_part_delta, msg) in aenumerate(
                 amap_in_executor(
                     lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
@@ -199,8 +213,31 @@ class AllReduceRunner(ServicerBase):
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
                 sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
-                    yield msg
+
+                if not self.should_delay_results(context.remote_id):
+                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                        yield msg
+
+                else:
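+                    # buffer the averaged parts instead of streaming them back right away: this sender should
+                    # start receiving results only after it has finished sending all of its own parts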
+                    done_receiving = asyncio.Event()
+                    delayed_results = asyncio.Queue()
+
+                    async def _accumulate_parts():
+                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
+                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
+                            delayed_results.put_nowait(msg)
+                        delayed_results.put_nowait(None)
+
+                    accumulate_task = asyncio.create_task(_accumulate_parts())
+
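+                    # block until the sender's input stream has been consumed, only then start yielding results back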
+                    await done_receiving.wait()
+
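+                    # drain the buffered results; None is the end-of-stream marker appended by _accumulate_parts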
+                    while True:
+                        next_result = await delayed_results.get()
+                        if next_result is None:
+                            break
+                        yield next_result
+                    await accumulate_task
 
             except Exception as e:
                 self.finalize(exception=e)
@@ -239,8 +276,7 @@ class AllReduceRunner(ServicerBase):
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # Coroutines are lazy, so we take the first item to start the couroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""