@@ -1,252 +1,229 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
+from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
+from enum import Enum

 import grpc
 import torch

-from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
-from hivemind.utils import split_for_streaming, combine_from_streaming
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+from hivemind.proto import averaging_pb2_grpc, averaging_pb2

 # flavour types
 GroupID = bytes
 logger = get_logger(__name__)


-class AllReduceProtocol:
+class AveragingMode(Enum):
+    NODE = 0
+    CLIENT = 1
+    AUX = 2
+
+
+class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers

+    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
-    :param part_sizes: for each peer, a number of vector elements that this peer is responsible for averaging
-    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
-        default (False) - return averaged_tensors by themselves
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+        (the actual number of values per peer will be nearly proportional, but there are no exact guarantees)
+    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
+    :param gathered: additional user-defined data collected from this group
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """

-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
+    def __init__(
+            self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+            ordered_group_endpoints: Sequence[Endpoint], peer_fractions: Tuple[float, ...],
+            weights: Optional[Sequence[float]] = None, modes: Optional[Sequence[AveragingMode]] = None,
+            gathered: Optional[Dict[Endpoint, Any]] = None, **kwargs):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint = group_id, endpoint
-        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
-                                      if part_size == 0}
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
-        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
-        self.return_deltas = return_deltas
-
-        self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint])
-        self.denominator = 0.0  # number of peers added to accumulator or sum of their weights
-        self.accumulated_from: Set[Endpoint] = set()  # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
-        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        for endpoint in self.client_mode_endpoints:
-            self.averaged_tensor_parts[endpoint] = torch.tensor([])
+        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
+        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
+        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
+        for mode, frac, weight in zip(modes, peer_fractions, weights):
+            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
+            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
+
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+
+        self._future = asyncio.Future()
+
+        self.sender_endpoints, self.sender_weights = [], []
+        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+            if mode != AveragingMode.AUX:
+                self.sender_endpoints.append(endpoint)
+                self.sender_weights.append(weight)
+
+        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.tensor_part_reducer = TensorPartReducer(tuple(part.shape for part in self.parts_for_local_averaging),
+                                                     len(self.sender_endpoints), self.sender_weights)

     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"

-    def __await__(self):
-        return self.future.__await__()
+    def __aiter__(self):
+        return self.run()

     def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.local_tensor_parts
+        return endpoint in self.ordered_group_endpoints

     @property
     def group_size(self):
         return len(self.ordered_group_endpoints)

-    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
-        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
-        assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
-        assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-        logger.debug(f"{self} - accumulating tensor part from {source}")
-
-        self.accumulator.add_(remote_part, alpha=weight)
-        self.denominator += weight
-        self.accumulated_from.add(source)
-
-        assert len(self.accumulated_from) <= self.group_size
-        if len(self.accumulated_from) == len(self.local_tensor_parts):
-            average_result = self.accumulator.div_(self.denominator)
-            self.register_averaged_part(self.endpoint, average_result)
-            self.averaged_part.set_result(average_result)
-
-        return await self.averaged_part
-
-    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
-        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
-        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
-        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
-        logger.debug(f"{self} - receiving averaged tensor part from {source}")
-        self.averaged_tensor_parts[source] = averaged_part
-        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
-            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
-
-            if self.return_deltas:
-                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
-                with torch.no_grad():
-                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
-                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
-                        averaged_tensor -= original_tensor
-
-            self.future.set_result(outputs)
-
-    def cancel(self) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - cancelled")
-            self.future.cancel()
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
-            return False
-
-    def set_exception(self, exception: Exception) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - {exception}")
-            self.future.set_exception(exception)
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
-            return False
-
-
-class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
-    """
-
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
-        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
-        self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)

-    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
-        if peer_endpoint == self.endpoint:
-            return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
-        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
-
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
-        for chunk in chunks:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
-        await stream.done_writing()
-
-        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
-        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
-        if code != averaging_pb2.AVERAGED_PART:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
-                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
-                                     f" allreduce failed")
-
+    async def run(self) -> AsyncIterator[torch.Tensor]:
+        """ Run all-reduce, return differences between averaged and original tensors as they are computed """
+        pending_tasks = set()
         try:
-            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
-                [message.tensor_part for message in outputs]))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+            if len(self.sender_endpoints) == 0:
+                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
+                self.finalize()

-        self.register_averaged_part(peer_endpoint, averaged_part)
-        return averaged_part
+            elif self.endpoint in self.sender_endpoints:
+                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+                    if parts != 0:
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))

-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
+                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+                self.finalize()
+
+            else:  # auxiliary peer
+                await self.tensor_part_reducer.finished.wait()
+                self.finalize()

-    async def run(self) -> Sequence[torch.Tensor]:
-        """
-        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
-        """
-        try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                         for i, peer in enumerate(self.ordered_group_endpoints)
-                                         if peer not in self.client_mode_endpoints))
-            return await self
         except BaseException as e:
+            self.finalize(exception=e)
+            for task in pending_tasks:
+                task.cancel()
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            self.set_exception(e)
-            for peer_endpoint, part_size in zip(self.ordered_group_endpoints, self.part_sizes):
-                if peer_endpoint != self.endpoint and part_size > 0:
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
                     asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise

-    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
-                                        ) -> Iterable[runtime_pb2.Tensor]:
-        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        try:
-            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
+        if peer_endpoint == self.endpoint:
+            sender_index = self.sender_endpoints.index(peer_endpoint)
+            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

-        averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
-        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-        return stream_chunks
+        else:
+            loop = asyncio.get_event_loop()
+            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
+
+            try:
+                code = None
+                async for part_index, msg in aenumerate(stream):
+                    if code is None:
+                        code = msg.code
+                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+                await write_task
+
+                if code != averaging_pb2.AVERAGED_PART:
+                    raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                                             f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                                             f", allreduce failed")
+            finally:
+                if not write_task.done():
+                    write_task.cancel()
+
+    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+        first_part = await anext(parts_aiter)
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
+                                                       group_id=self.group_id, endpoint=self.endpoint,
+                                                       tensor_part=first_part))
+        async for part in parts_aiter:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
+
+        await stream.done_writing()

     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
+        """ a peer sends us a part of his tensor; we should average it with other peers and return the difference """
         request: averaging_pb2.AveragingData = await anext(stream)
-
-        if request.group_id != self.group_id:
-            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        reason_to_reject = self._check_reasons_to_reject(request)
+        if reason_to_reject:
+            yield reason_to_reject
+            return

         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
-                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
-                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
-                for averaged_chunk in averaged_chunks:
-                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+                sender_index = self.sender_endpoints.index(request.endpoint)
+                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                    yield msg

             except Exception as e:
-                self.set_exception(e)
+                self.finalize(exception=e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

+    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        elif self._future.cancelled():
+            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        elif self._future.done():
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
+        loop = asyncio.get_event_loop()
+        async for part_index, (tensor_part, part_compression) in aenumerate(
+                amap_in_executor(lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression), stream,
+                                 max_prefetch=self.tensor_part_container.prefetch)):
+            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+
+            serialized_delta = await loop.run_in_executor(
+                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression))
+            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(tuple(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()

-class AllreduceException(Exception):
-    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """
+    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        if not self._future.done():
+            if cancel:
+                logger.debug(f"{self} - cancelled")
+                self._future.cancel()
+            elif exception:
+                logger.debug(f"{self} - caught {exception}")
+                self._future.set_exception(exception)
+            else:
+                logger.debug(f"{self} - finished")
+                self._future.set_result(None)
+            self.tensor_part_container.finalize()
+            self.tensor_part_reducer.finalize()
+            return True
+        else:
+            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
+            return False
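
Reviewer note (not part of the patch): a minimal usage sketch of the refactored interface. It assumes the module lives at hivemind.client.averaging.allreduce (matching the import paths above), that a degenerate single-peer group can be reduced without a running gRPC server, and that part_size_bytes is a valid TensorPartContainer keyword; the endpoint string and sizes are placeholders.

# Hypothetical sketch, not a definitive example of the library's API.
import asyncio

import torch

from hivemind.client.averaging.allreduce import AllReduceRunner, AveragingMode


async def average_locally():
    # Degenerate group of one peer: all parts are reduced in-process, no gRPC traffic.
    # "localhost:1337" and part_size_bytes are placeholder values.
    tensors = [torch.randn(16), torch.randn(3, 3)]
    runner = AllReduceRunner(
        group_id=b"demo-group",
        tensors=tensors,
        endpoint="localhost:1337",
        ordered_group_endpoints=["localhost:1337"],
        peer_fractions=(1.0,),
        modes=(AveragingMode.NODE,),
        part_size_bytes=2 ** 16,  # forwarded to TensorPartContainer via **kwargs
    )
    # The runner is now an async iterator: it yields one delta (averaged - local)
    # per input tensor as soon as that tensor has been reduced.
    index = 0
    async for delta in runner:
        tensors[index].add_(delta)  # local + (average - local) == average
        index += 1


if __name__ == "__main__":
    asyncio.run(average_locally())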