Delta gradients transmission (#225)

* Implemented delta gradient transmission
Vsevolod-pl 4 years ago
parent
commit
91d17a4ebc
1 changed file with 5 additions and 10 deletions

+ 5 - 10
hivemind/client/averaging/allreduce.py

@@ -138,7 +138,6 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                          ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
         self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-        self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
@@ -165,7 +164,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                                      f" allreduce failed")
 
         try:
-            averaged_part = deserialize_torch_tensor(combine_from_streaming(
+            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
                 [message.tensor_part for message in outputs]))
         except RuntimeError as e:
             raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
@@ -205,17 +204,13 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
             raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
 
         averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        if not self.averaged_part_stream.done():
-            serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
-            stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-            self.averaged_part_stream.set_result(stream_chunks)
-            return stream_chunks
-        else:
-            return self.averaged_part_stream.result()
+        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
+        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
+        return stream_chunks
 
     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
         request: averaging_pb2.AveragingData = await anext(stream)
 
         if request.group_id != self.group_id:
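
In essence, the change makes each receiving peer respond with the difference averaged_part - tensor_part rather than the full averaged part, while the requesting peer reconstructs the average as local_part + delta (presumably so the serialized payload compresses better under the configured compression_type). Below is a minimal standalone sketch of that exchange; the helpers aggregate_and_make_delta and apply_delta are illustrative names, not hivemind APIs, and a plain weighted average stands in for AllReduceProtocol's accumulation logic.

import torch

def aggregate_and_make_delta(received_part: torch.Tensor,
                             all_parts: list, weights: list) -> torch.Tensor:
    """Aggregator side: average all parts, but send back only the delta w.r.t. the received part."""
    total_weight = sum(weights)
    averaged_part = sum(w * p for w, p in zip(weights, all_parts)) / total_weight
    return averaged_part - received_part  # this delta is what would be serialized and streamed back

def apply_delta(local_part: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    """Sender side: recover the averaged part by adding the received delta to the local part."""
    return local_part + delta

if __name__ == "__main__":
    parts = [torch.randn(4) for _ in range(3)]
    weights = [1.0, 1.0, 1.0]
    delta = aggregate_and_make_delta(parts[0], parts, weights)
    averaged = apply_delta(parts[0], delta)
    assert torch.allclose(averaged, sum(parts) / 3)  # same result as transmitting the full average

This also appears to be why the cached averaged_part_stream future is removed in the first hunk: the response now depends on each source's own tensor_part, so a single serialized result can no longer be shared across incoming requests.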