@@ -138,7 +138,6 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                          ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
         self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-        self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
@@ -165,7 +164,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                                      f" allreduce failed")
 
         try:
-            averaged_part = deserialize_torch_tensor(combine_from_streaming(
+            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
                 [message.tensor_part for message in outputs]))
         except RuntimeError as e:
             raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
@@ -205,17 +204,13 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
             raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
 
         averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        if not self.averaged_part_stream.done():
-            serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
-            stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-            self.averaged_part_stream.set_result(stream_chunks)
-            return stream_chunks
-        else:
-            return self.averaged_part_stream.result()
+        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
+        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
+        return stream_chunks
 
     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
        request: averaging_pb2.AveragingData = await anext(stream)
 
         if request.group_id != self.group_id:
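For context on the hunks above: the all-reduce now exchanges deltas rather than fully averaged parts. The peer that aggregates a part replies with averaged_part - tensor_part, and the sender restores the average locally as local_part + delta; this is also why the shared averaged_part_stream cache is removed, since every sender now receives a different payload. Below is a minimal, self-contained sketch of that arithmetic; the names (aggregate, local_parts, restored) are illustrative only, not part of the hivemind API, and equal peer weights are assumed for simplicity.

import torch

def aggregate(parts, weights):
    """Weighted average of one tensor part across peers; returns the delta owed to each sender."""
    averaged = sum(w * p for w, p in zip(weights, parts)) / sum(weights)
    return [averaged - p for p in parts]  # each sender gets back only the difference

local_parts = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 6.0])]
deltas = aggregate(local_parts, weights=[1.0, 1.0])
restored = [p + d for p, d in zip(local_parts, deltas)]  # local_part + delta == true average
assert all(torch.allclose(r, torch.tensor([2.0, 4.0])) for r in restored)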