Artem Chumachenko 3 tahun lalu
induk
melakukan
413e06741f
1 mengubah file dengan 6 tambahan dan 6 penghapusan
  1. 6 6
      hivemind/optim/power_sgd_averager.py

+ 6 - 6
hivemind/optim/power_sgd_averager.py

@@ -4,7 +4,6 @@ import math
 import multiprocessing as mp
 from typing import Any, Iterable, Optional, Sequence
 
-import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
@@ -13,14 +12,12 @@ from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.dht import DHT
-from hivemind.p2p import P2P
+from hivemind.optim.grad_averager import GradientAverager
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import as_aiter, azip, enter_asynchronously
 from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
 from hivemind.utils.timed_storage import get_dht_time
 
-from .grad_averager import GradientAverager
-
 GatheredData = Any
 logger = get_logger(__name__)
 
@@ -51,7 +48,7 @@ class PowerSGDGradientAverager(GradientAverager):
 
     :param parameters: pytorch parameters for which to aggregate gradients
     :param averager_rank: compress gradient tensors
-    :param min_comprasion_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor, otherwise aggregate tensors as is
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor, otherwise aggregate tensors as is
     :param dht: a DHT isntance connected to the rest of the swarm. See hivemind.DHT docs
     :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
     :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
@@ -86,7 +83,10 @@ class PowerSGDGradientAverager(GradientAverager):
             for i, grad in enumerate(self._grads_from_parameters())
             if len(tuple(grad.size())) <= 1
             or (
-                1 - self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size())
+                1
+                - self.rank
+                * sum(get_flatten_greedy_dims(grad))
+                / (get_flatten_greedy_dims(grad)[0] * get_flatten_greedy_dims(grad)[1])
                 < min_compression_ratio
             )  # compute how much parameters can we left via factorization
         )