|
|
@@ -2,7 +2,7 @@
|
|
|
Auxiliary data structures for AllReduceRunner
|
|
|
"""
|
|
|
import asyncio
|
|
|
-from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
|
|
|
+from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator, Type, Deque, List
|
|
|
from collections import deque
|
|
|
|
|
|
import torch
|
|
|
@@ -32,7 +32,7 @@ class TensorPartContainer:
|
|
|
self,
|
|
|
tensors: Sequence[torch.Tensor],
|
|
|
peer_fractions: Sequence[float],
|
|
|
- compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
|
|
|
+ compression_type: Union[Type[CompressionType], Sequence[Type[CompressionType]]] = CompressionType.NONE,
|
|
|
part_size_bytes: int = 2 ** 20,
|
|
|
prefetch: int = 1,
|
|
|
):
|
|
|
@@ -42,8 +42,10 @@ class TensorPartContainer:
|
|
|
self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
|
|
|
self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
|
|
|
self.total_size = sum(tensor.numel() for tensor in tensors)
|
|
|
- self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
|
|
|
- self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
|
|
|
+ self._input_parts_by_peer: List[Deque[Tuple[torch.Tensor, Type[CompressionType]]]] = [
|
|
|
+ deque() for _ in range(self.group_size)
|
|
|
+ ]
|
|
|
+ self._output_parts_by_peer: List[Deque[torch.Tensor]] = [deque() for _ in range(self.group_size)]
|
|
|
self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
|
|
|
self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
|
|
|
self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
|
|
|
@@ -124,7 +126,7 @@ class TensorPartContainer:
|
|
|
self._outputs_consumed = True
|
|
|
peer_index = num_parts_processed = 0
|
|
|
for tensor_index in range(len(self.local_tensors)):
|
|
|
- tensor_parts = []
|
|
|
+ tensor_parts: List[torch.Tensor] = []
|
|
|
while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
|
|
|
if num_parts_processed >= self.num_parts_by_peer[peer_index]:
|
|
|
num_parts_processed = 0
|
|
|
@@ -173,7 +175,7 @@ class TensorPartReducer:
|
|
|
assert all(isinstance(weight, (int, float)) for weight in self.weights)
|
|
|
self.current_part_index = -1 # index in local_parts of the part that should be loaded next
|
|
|
self.current_part_accumulated_from = 0 # number of peers from which the current part was accumulated
|
|
|
- self.accumulator = None # this will contain the sum of current tensor part from group peers
|
|
|
+ self.accumulator: Optional[torch.Tensor] = None # contains the sum of current tensor part from group peers
|
|
|
self.denominator = 0.0 # total weight accumulated from all peers for current part
|
|
|
self.current_part_future = asyncio.Future()
|
|
|
self.finished = asyncio.Event()
|