
another pack of fixes

Artem Chumachenko 3 years ago
parent
commit
a9bbf9a190

+ 1 - 1
hivemind/averaging/averager.py

@@ -509,7 +509,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     @contextlib.contextmanager
     def _register_allreduce_group(self, group_info: GroupInfo):
-        """Register a given group all-reduce for one or more all-reduce rounds"""
+        """Register a given group for one or more all-reduce rounds"""
         try:
             self._running_groups[group_info.group_id] = asyncio.Future()
             self._pending_groups_registered.set()
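
For reference, the hunk above shows the group registration pattern used by the averager: it stores an asyncio.Future under the group id so that a later all-reduce round can await it, and signals waiters through an event. The sketch below is a minimal stand-in for that pattern, not hivemind's actual class; the clean-up in the finally block is an assumption about the part of the method that the diff context cuts off.

    import asyncio
    import contextlib


    class GroupRegistry:
        """Toy stand-in for the register/clean-up pattern shown above (not hivemind code)."""

        def __init__(self):
            self.running_groups = {}                  # group_id -> asyncio.Future
            self.pending_registered = asyncio.Event()

        @contextlib.contextmanager
        def register_group(self, group_id: bytes):
            try:
                self.running_groups[group_id] = asyncio.Future()
                self.pending_registered.set()         # wake up anything waiting for new groups
                yield
            finally:
                # assumed clean-up: forget the group once its all-reduce rounds are over
                self.running_groups.pop(group_id, None)


    async def demo():
        registry = GroupRegistry()
        with registry.register_group(b"group-1"):
            assert b"group-1" in registry.running_groups
        assert b"group-1" not in registry.running_groups


    asyncio.run(demo())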

+ 1 - 1
hivemind/optim/grad_averager.py

@@ -36,7 +36,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
-    :param average_grads: if provided, it will be used as a set of averagable gradients
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
 
 
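Since this docstring fixes the name of the averaged_grads parameter, here is a hedged construction sketch showing where it is passed. The model, prefix, and pre-allocated buffers are illustrative assumptions; in typical usage you omit averaged_grads entirely and let the averager allocate its own shared buffers.

    import torch
    import hivemind
    from hivemind.optim.grad_averager import GradientAverager

    model = torch.nn.Linear(16, 4)
    dht = hivemind.DHT(start=True)

    # hypothetical pre-allocated buffers with the same shapes as the gradients
    averaged_grads = tuple(torch.zeros_like(p) for p in model.parameters())

    averager = GradientAverager(
        model.parameters(),
        dht=dht,
        prefix="demo_grad_averager",    # assumed experiment prefix
        averaged_grads=averaged_grads,  # the parameter documented (and renamed) above
        start=True,
    )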

+ 1 - 1
hivemind/optim/optimizer.py

@@ -251,7 +251,7 @@ class Optimizer(torch.optim.Optimizer):
             optimizer=optimizer,
             params=params,
             scheduler=scheduler,
-            delta_rule_averaging=grad_averager is None and self.delay_state_averaging,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
             compression=state_averaging_compression,
             state_compression=load_state_compression,
             average_opt_statistics=average_opt_statistics,
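
The corrected condition enables delta-rule state averaging only when local updates are combined with delayed state averaging. Below is a hedged usage sketch of the configuration that exercises this code path; the run id, batch sizes, and SGD settings are illustrative assumptions, while the keyword names follow hivemind.Optimizer's documented arguments.

    import torch
    import hivemind

    model = torch.nn.Linear(32, 2)
    base_opt = torch.optim.SGD(model.parameters(), lr=0.01)

    dht = hivemind.DHT(start=True)

    # use_local_updates together with delay_state_averaging is the combination
    # that the corrected delta_rule_averaging flag now reflects
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",             # assumed experiment name
        optimizer=base_opt,
        batch_size_per_step=32,        # illustrative batch accounting
        target_batch_size=1024,
        use_local_updates=True,
        delay_state_averaging=True,
        verbose=True,
    )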

+ 22 - 41
hivemind/optim/power_sgd_averager.py

@@ -1,6 +1,5 @@
 import asyncio
 import contextlib
-import faulthandler
 import math
 import multiprocessing as mp
 from typing import Any, Iterable, Optional, Sequence
@@ -8,37 +7,17 @@ from typing import Any, Iterable, Optional, Sequence
 import numpy as np
 import torch
 
-from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
-from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
-from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    TensorRole,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
-from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2
-from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import (
-    achain,
-    aiter_with_timeout,
-    anext,
-    as_aiter,
-    azip,
-    enter_asynchronously,
-    switch_to_uvloop,
-)
-from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
-from hivemind.utils.math import orthogonalize_
-from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.p2p import P2P
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import as_aiter, azip, enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+from hivemind.utils.timed_storage import get_dht_time
 
 from .grad_averager import GradientAverager
 
@@ -50,26 +29,26 @@ class PowerSGDGradientAverager(GradientAverager):
     """
     A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
     For basic properties and guaranties of gradient averagers, please refer to the base class docstring.
-    Put simply, this method approximates large gradient tensors (m,n) with a product of two  
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
     smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
-    
+
     As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
     High r, e.g. sqrt(max(m, n)) typically reduce communication by 2-8x without affecting convergence.
     Low r, e.g. 1-8, further accelerate communication, but may converge worse depending on the task.
-    
+
     To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
     if some part of the gradient is "lost in compression", it will be added to the next iteration.
     This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
     and (b) if devices stay alive only for one step, training with small rank may converge slower.
     This is because error feedback takes multiple step to kick in.
-    
+
     Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
     If a tensor has less than 2 dimensions or does not compress efficiently, it will be aggregated
     normally, i.e. without powerSGD. See min_compression_ratio for details.
-    
+
     :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
      matrix of shape (256, 256) be compressed differently if you .reshape it to (32, 32, 32).
-    
+
     :param parameters: pytorch parameters for which to aggregate gradients
     :param averager_rank: compress gradient tensors
     :param min_comprasion_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor, otherwise aggregate tensors as is
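
To make the O((m + n) * r) claim above concrete, the short calculation below counts the elements communicated with and without PowerSGD for one illustrative gradient matrix and evaluates the same saving fraction that min_compression_ratio is compared against; the shape and rank are arbitrary examples.

    m, n, rank = 1024, 4096, 32            # illustrative gradient shape and averager_rank

    full = m * n                           # elements sent without compression: O(m * n)
    compressed = rank * (m + n)            # P is (m, r), Q is (r, n): O((m + n) * r)
    print(full, compressed, full / compressed)   # 4194304 163840 25.6 -> ~26x less traffic

    # fraction of elements saved, the quantity compared against min_compression_ratio
    saved = 1 - rank * (m + n) / (m * n)
    print(round(saved, 3))                 # 0.961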
@@ -84,6 +63,7 @@ class PowerSGDGradientAverager(GradientAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     """
+
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
@@ -104,18 +84,19 @@ class PowerSGDGradientAverager(GradientAverager):
         self._uncompressed_gradients_indexes = set(
             i
             for i, grad in enumerate(self._grads_from_parameters())
-            if len(tuple(grad.size())) == 1
+            if len(tuple(grad.size())) <= 1
             or (
-                1 - self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) < min_compression_ratio
-            ) # compute how much parameters can we left via factorization
+                1 - self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size())
+                < min_compression_ratio
+            )  # compute how much parameters can we left via factorization
         )
         self._ms = [
-            torch.zeros_like(grad, device="cpu").share_memory_() 
+            torch.zeros_like(grad, device="cpu").share_memory_()
             for idx, grad in enumerate(self._grads_from_parameters())
             if idx not in self._uncompressed_gradients_indexes
         ]
         self._qs = [
-            torch.rand((np.prod(grad.size()[1:]), self.rank), device="cpu").share_memory_()
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
             for idx, grad in enumerate(self._grads_from_parameters())
             if idx not in self._uncompressed_gradients_indexes
         ]
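
The hunk above is the per-tensor selection logic: one-dimensional gradients and matrices that would not compress well are kept out of PowerSGD, while every remaining gradient gets an error-feedback buffer in _ms and a Q factor in _qs. The snippet below replays that decision for a few toy gradients using the same criterion; the rank, threshold, and shapes are assumptions for illustration only.

    import numpy as np
    import torch

    rank, min_compression_ratio = 4, 0.5   # illustrative settings

    # toy gradients: a bias vector, a compressible matrix, and a matrix too small to bother with
    grads = [torch.zeros(256), torch.zeros(128, 64), torch.zeros(8, 4)]

    for i, grad in enumerate(grads):
        if grad.dim() <= 1:
            print(i, tuple(grad.shape), "-> aggregated as-is (vector)")
            continue
        saved = 1 - rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size())
        if saved < min_compression_ratio:
            print(i, tuple(grad.shape), "-> aggregated as-is (compression not worth it)")
        else:
            # for a 2d gradient, get_flatten_greedy_dims(grad) is simply its shape,
            # so the corresponding entry of _qs would have shape (grad.size(1), rank)
            print(i, tuple(grad.shape), "-> PowerSGD, Q factor shape:", (grad.size(1), rank))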
@@ -172,7 +153,7 @@ class PowerSGDGradientAverager(GradientAverager):
                     m.add_(grad.to(m.device))
 
                 ps = [
-                    torch.zeros((grad.size(0), self.rank), device="cpu")
+                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
                     for idx, grad in enumerate(averaged_grad_via_sgd)
                 ]
                 for p, q, m in zip(ps, self._qs, self._ms):
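
The zip above drives one PowerSGD power-iteration step per compressed tensor: each error buffer m produces a P factor, P is orthogonalized, a new Q is computed, and P and Q are what actually gets all-reduced. As a reference for the arithmetic, here is a plain single-tensor sketch with error feedback; orthogonalize_ follows the Gram-Schmidt helper in hivemind/utils/math.py but is partly reconstructed (the diff shows only part of it), and the buffer names, shapes, and explicit residual update are local to this example rather than the averager's actual all-reduce code.

    import torch


    def orthogonalize_(matrix, eps: float = 1e-8):
        # Gram-Schmidt over columns; the tail matches hivemind/utils/math.py, the head is reconstructed
        n, m = matrix.shape
        for i in range(m):
            col = matrix[:, i]
            col /= col.norm() + eps
            if i + 1 < m:
                rest = matrix[:, i + 1:]
                rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)


    torch.manual_seed(0)
    rank = 4
    grad = torch.randn(128, 64)          # illustrative flattened gradient
    m_buf = torch.zeros_like(grad)       # error-feedback buffer, playing the role of _ms
    q = torch.randn(64, rank)            # persistent Q factor, playing the role of _qs

    m_buf += grad                        # accumulate the new gradient on top of the residual
    p = m_buf @ q                        # (128, rank); in the averager this factor is all-reduced
    orthogonalize_(p)
    q = m_buf.t() @ p                    # (64, rank); the second all-reduced factor
    approx = p @ q.t()                   # rank-r approximation of the accumulated gradient
    m_buf -= approx                      # keep the compression error for the next step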

+ 11 - 1
hivemind/utils/math.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 
 
-@torch.jit.script      
+@torch.jit.script
 def orthogonalize_(matrix, eps: float = 1e-8):
     """Orthogonalize a 2d tensor in-place over the last dimension"""
     n, m = matrix.shape
@@ -12,3 +12,13 @@ def orthogonalize_(matrix, eps: float = 1e-8):
         if i + 1 < m:
             rest = matrix[:, i + 1 :]
             rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
+
+
+def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
+    """get dims to flatten tensor upto max_ndim dimensions by merging small axes together"""
+    dims = list(tensor.shape)
+    while len(dims) > max_ndim:
+        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
+        squeezed_dim = dims.pop(squeeze_ix)
+        dims[squeeze_ix] *= squeezed_dim
+    return dims
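
A brief usage note on the new helper: instead of tensor.flatten(1, -1), it repeatedly merges the adjacent pair of axes with the smallest product, which often gives a better balanced 2d view for PowerSGD. The shapes below are illustrative; the printed values follow from the merge rule above.

    import torch

    from hivemind.utils.math import get_flatten_greedy_dims

    conv_grad = torch.zeros(64, 3, 3, 3)
    print(get_flatten_greedy_dims(conv_grad))   # [64, 27], same as flatten(1, -1) here

    odd_grad = torch.zeros(8, 3, 3, 64)
    print(get_flatten_greedy_dims(odd_grad))    # [72, 64]; flatten(1, -1) would give (8, 576)
    print(odd_grad.reshape(get_flatten_greedy_dims(odd_grad)).shape)   # torch.Size([72, 64])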