@@ -8,6 +8,7 @@ from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, Type
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,23 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(
+        self,
+        part_shapes: Sequence[torch.Size],
+        num_senders: int,
+        *,
+        weights: Optional[Sequence[float]],
+        accumulator_factory: AccumulatorFactory,
+    ):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
-        self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
+        self.accumulator_factory = accumulator_factory
+        self.accumulator = None
         self.finished = asyncio.Event()
         self.reset_accumulators()
 
@@ -194,8 +202,7 @@ class TensorPartReducer:
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
-        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-        self.denominator = 0.0
+        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)
 
     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,13 +218,12 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
         self.current_part_accumulated_from += 1
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            current_part_future.set_result(self.accumulator.reduce())
             self.reset_accumulators()
         return await current_part_future
 
@@ -225,7 +231,7 @@ class TensorPartReducer:
         if not self.finished.is_set():
             if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
-                del self.accumulator
+                self.accumulator = None
             self.finished.set()
 
     def __del__(self):
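
For reference, the call sites above imply that an object returned by accumulator_factory(part_shape, num_senders) must expose accumulate_part(tensor_part, weight) and reduce(). The module hivemind.averaging.accumulators is not shown in this diff, so the following is only a minimal sketch under that assumption; the class name MeanAccumulator is hypothetical. It reproduces the weighted-mean behavior of the removed inline add_/div_ code.

# Hypothetical sketch, not taken from this diff: a weighted-mean accumulator
# matching the interface TensorPartReducer uses above.
import torch


class MeanAccumulator:
    def __init__(self, part_shape: torch.Size, num_senders: int):
        # num_senders is accepted only to match the factory call signature; unused here
        # running weighted sum and total weight, as in the removed inline logic
        self.accumulator = torch.zeros(part_shape)
        self.denominator = 0.0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        # old behavior: accumulator.add_(tensor_part, alpha=weight); denominator += weight
        self.accumulator.add_(tensor_part, alpha=weight)
        self.denominator += weight

    def reduce(self) -> torch.Tensor:
        # old behavior: accumulator.div_(denominator)
        return self.accumulator.div_(self.denominator)

With such a factory, TensorPartReducer(part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator) would behave like the previous hard-coded weighted averaging.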