@@ -1,6 +1,5 @@
 import asyncio
 import contextlib
-import faulthandler
 import math
 import multiprocessing as mp
 from typing import Any, Iterable, Optional, Sequence
@@ -8,37 +7,17 @@ from typing import Any, Iterable, Optional, Sequence
 import numpy as np
 import torch
 
-from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
-from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
-from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    TensorRole,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
-from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2
-from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import (
-    achain,
-    aiter_with_timeout,
-    anext,
-    as_aiter,
-    azip,
-    enter_asynchronously,
-    switch_to_uvloop,
-)
-from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
-from hivemind.utils.math import orthogonalize_
-from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.p2p import P2P
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import as_aiter, azip, enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+from hivemind.utils.timed_storage import get_dht_time
 
 from .grad_averager import GradientAverager
 
@@ -50,26 +29,26 @@ class PowerSGDGradientAverager(GradientAverager):
     """
     A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
     For basic properties and guarantees of gradient averagers, please refer to the base class docstring.
-    Put simply, this method approximates large gradient tensors (m,n) with a product of two 
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
     smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
-    
+
     As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
     High r, e.g. sqrt(max(m, n)), typically reduces communication by 2-8x without affecting convergence.
     Low r, e.g. 1-8, further accelerates communication, but may converge worse depending on the task.
-    
+
     To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
     if some part of the gradient is "lost in compression", it will be added to the next iteration.
     This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
     and (b) if devices stay alive only for one step, training with small rank may converge slower.
     This is because error feedback takes multiple steps to kick in.
-    
+
     Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
     If a tensor has fewer than 2 dimensions or does not compress efficiently, it will be aggregated
     normally, i.e. without PowerSGD. See min_compression_ratio for details.
-    
+
     :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
     matrix of shape (256, 256) will be compressed differently if you .reshape it to (32, 32, 32).
-    
+
     :param parameters: pytorch parameters for which to aggregate gradients
     :param averager_rank: rank of the approximation used to compress gradient tensors
     :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor, otherwise aggregate tensors as is
@@ -84,6 +63,7 @@ class PowerSGDGradientAverager(GradientAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     """
+
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
@@ -104,18 +84,19 @@ class PowerSGDGradientAverager(GradientAverager):
         self._uncompressed_gradients_indexes = set(
             i
             for i, grad in enumerate(self._grads_from_parameters())
-            if len(tuple(grad.size())) == 1
+            if len(tuple(grad.size())) <= 1
             or (
-                1 - self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) < min_compression_ratio
-            ) # compute how much parameters can we left via factorization
+                1 - self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size())
+                < min_compression_ratio
+            )  # fraction of parameters saved by factorization is below min_compression_ratio
         )
         self._ms = [
-            torch.zeros_like(grad, device="cpu").share_memory_() 
+            torch.zeros_like(grad, device="cpu").share_memory_()
             for idx, grad in enumerate(self._grads_from_parameters())
             if idx not in self._uncompressed_gradients_indexes
         ]
         self._qs = [
-            torch.rand((np.prod(grad.size()[1:]), self.rank), device="cpu").share_memory_()
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
             for idx, grad in enumerate(self._grads_from_parameters())
             if idx not in self._uncompressed_gradients_indexes
         ]
@@ -172,7 +153,7 @@ class PowerSGDGradientAverager(GradientAverager):
                 m.add_(grad.to(m.device))
 
             ps = [
-                torch.zeros((grad.size(0), self.rank), device="cpu")
+                torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
                 for idx, grad in enumerate(averaged_grad_via_sgd)
             ]
             for p, q, m in zip(ps, self._qs, self._ms):
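For context on what this diff wires together, here is a minimal, single-process sketch of the two mechanisms involved: the min_compression_ratio test that decides whether a tensor is worth factorizing, and one rank-r PowerSGD step with error feedback. It is plain PyTorch, not hivemind code: `should_compress` and `power_sgd_step` are hypothetical helpers, the flattening follows the docstring's `tensor.flatten(1, -1)` convention rather than `get_flatten_greedy_dims`, `torch.linalg.qr` stands in for `orthogonalize_`, and the all-reduce of `p` and `q` across peers is omitted.

```python
import torch


def should_compress(grad: torch.Tensor, rank: int, min_compression_ratio: float = 0.5) -> bool:
    """Apply PowerSGD to a tensor only if the rank-r factorization saves enough communication."""
    if grad.dim() <= 1:
        return False  # vectors and scalars are aggregated uncompressed
    d0, d1 = grad.shape[0], grad[0].numel()  # view the tensor as a (d0, d1) matrix
    saved_fraction = 1 - rank * (d0 + d1) / (d0 * d1)
    return saved_fraction >= min_compression_ratio


def power_sgd_step(grad: torch.Tensor, error: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """One local PowerSGD step with error feedback; `q` is (d1, rank) and persists across steps."""
    error.add_(grad)                    # error feedback: fold in what previous steps lost
    m = error.flatten(1, -1)            # (d0, d1) view of the accumulated gradient
    p = m @ q                           # (d0, rank); peers would all-reduce this
    p, _ = torch.linalg.qr(p)           # orthogonalize columns (the role of orthogonalize_)
    q.copy_(m.t() @ p)                  # (d1, rank); peers would all-reduce this too
    approx = (p @ q.t()).view_as(grad)  # rank-r approximation of the accumulated gradient
    error.sub_(approx)                  # keep the residual for the next iteration
    return approx


if __name__ == "__main__":
    torch.manual_seed(0)
    grad, rank = torch.randn(256, 256), 8
    assert should_compress(grad, rank)           # 1 - 8 * 512 / 65536 ≈ 0.94 of the values saved
    error = torch.zeros_like(grad)
    q = torch.randn(grad[0].numel(), rank)
    approx = power_sgd_step(grad, error, q)
    print(torch.allclose(approx + error, grad))  # True: nothing is lost, only deferred
```

The in-place updates of `error` and `q` mirror the persistent `self._ms` and `self._qs` buffers above: whatever the rank-r product cannot represent stays in the feedback buffer and is retried on the next step.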