Artem Chumachenko 3 years ago
parent
commit
795ca8c30c

+ 1 - 1
hivemind/averaging/averager.py

@@ -456,7 +456,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     group_info = await matchmaking_task
 
                     if group_info is None:
-                         raise AllreduceException("Averaging step failed: could not find a group")
+                        raise AllreduceException("Averaging step failed: could not find a group")
 
                     with self._register_allreduce_group(group_info):
                         step.stage = AveragingStage.RUNNING_ALLREDUCE

+ 4 - 3
hivemind/optim/optimizer.py

@@ -12,8 +12,8 @@ from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager
-from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
@@ -248,7 +248,8 @@ class Optimizer(torch.optim.Optimizer):
             assert len(extra_tensors) == 0
             grad_extra_tensors = [
                 torch.zeros_like(param, device="cpu")
-                for param_group in optimizer.param_groups for param in param_group["params"]
+                for param_group in optimizer.param_groups
+                for param in param_group["params"]
             ]
             for tensor in grad_extra_tensors:
                 if tensor is not None:
@@ -272,7 +273,7 @@ class Optimizer(torch.optim.Optimizer):
                 reuse_grad_buffers=reuse_grad_buffers,
                 grad_rank_averager=grad_rank_averager,
                 compression=grad_compression,
-                **grad_averager_opts or {}
+                **grad_averager_opts or {},
             )
         else:
             self.grad_averager = None
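
The reformatted comprehension above flattens parameters across all of the optimizer's param groups into a single list of CPU buffers. A minimal, self-contained sketch of that same pattern (the toy model, group layout, and hyperparameters are invented for illustration):

import torch

# Toy model and optimizer just to have more than one parameter group;
# the shapes and learning rates are arbitrary.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(
    [{"params": [model.weight]}, {"params": [model.bias], "lr": 0.1}], lr=0.01
)

# Same double comprehension as in the diff: one zeroed CPU buffer per parameter,
# in optimizer order, flattened across all parameter groups.
grad_extra_tensors = [
    torch.zeros_like(param, device="cpu")
    for param_group in optimizer.param_groups
    for param in param_group["params"]
]
assert len(grad_extra_tensors) == sum(len(g["params"]) for g in optimizer.param_groups)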

+ 29 - 21
hivemind/optim/power_ef_averager.py

@@ -2,12 +2,12 @@ import asyncio
 import contextlib
 import faulthandler
 import math
-import torch
 import multiprocessing as mp
-import numpy as np
-
 from typing import Any, Iterable, Optional, Sequence
 
+import numpy as np
+import torch
+
 import hivemind
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.control import AveragingStage, StepControl
@@ -64,16 +64,20 @@ class PowerEFGradientAverager(GradientAverager):
         self.rank = averager_rank
         self.parameters = tuple(parameters)
         self._uncompressed_gradients = set(
-            i for i, grad in enumerate(self._grads_from_parameters())
-            if len(tuple(grad.size())) == 1 or 
-            (self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio)
+            i
+            for i, grad in enumerate(self._grads_from_parameters())
+            if len(tuple(grad.size())) == 1
+            or (
+                self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio
+            )
         )
         self._gradient_rests = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
         self._qs = list(
             torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients
         )
-        for tensor in (self._qs + self._gradient_rests):
+        for tensor in self._qs + self._gradient_rests:
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
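
The _uncompressed_gradients set above keeps a gradient at full precision when a rank-r factorization would not actually shrink it: a gradient flattened to shape (n, m) is replaced by factors of shape (n, r) and (m, r), so compression only pays off while r * (n + m) / (n * m) stays below 1 - min_comprasion_ratio. A standalone sketch of that check (min_compression_ratio mirrors the constructor's min_comprasion_ratio argument; the 0.5 default here is an assumption):

import numpy as np

def keeps_full_precision(shape, rank, min_compression_ratio=0.5):
    # 1-D gradients (biases, norms) are never factorized.
    if len(shape) == 1:
        return True
    # For a gradient flattened to (n, m), rank-r factors cost r * (n + m)
    # numbers versus n * m for the raw tensor.
    n, m = shape[0], int(np.prod(shape[1:]))
    return rank * (n + m) / (n * m) > 1 - min_compression_ratio

print(keeps_full_precision((1024,), rank=4))       # True: 1-D, never compressed
print(keeps_full_precision((1024, 1024), rank=4))  # False: large matrix, worth compressing
print(keeps_full_precision((8, 8), rank=4))        # True: too small for rank-4 factors to help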
@@ -87,22 +91,22 @@ class PowerEFGradientAverager(GradientAverager):
             client_mode=client_mode,
             warn=warn,
             grad_extra_tensors=grad_extra_tensors,
-            **kwargs
+            **kwargs,
         )
 
     @contextlib.contextmanager
     def _register_allreduce_group(self, group_info: GroupInfo):
         """registers a given all-reduce runner to listen for incoming connections"""
         try:
-            self._running_groups[group_info.group_id + b'.phase1'] = asyncio.Future()
-            self._running_groups[group_info.group_id + b'.phase2'] = asyncio.Future()
+            self._running_groups[group_info.group_id + b".phase1"] = asyncio.Future()
+            self._running_groups[group_info.group_id + b".phase2"] = asyncio.Future()
             self._pending_groups_registered.set()
             yield
         finally:
-            maybe_future = self._running_groups.pop(group_info.group_id + b'.phase1', None)
+            maybe_future = self._running_groups.pop(group_info.group_id + b".phase1", None)
             if maybe_future and not maybe_future.done():
                 logger.warning(f"All-reduce group {group_info.group_id + b'.phase1'} did not finish.")
-            maybe_future = self._running_groups.pop(group_info.group_id + b'.phase2', None)
+            maybe_future = self._running_groups.pop(group_info.group_id + b".phase2", None)
             if maybe_future and not maybe_future.done():
                 logger.warning(f"All-reduce group {group_info.group_id + b'.phase2'} did not finish.")
             self._pending_groups_registered.set()
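
_register_allreduce_group now keys two futures per group, one per phase, so an incoming connection for either phase can wait until the matching runner is published. A hypothetical asyncio-only illustration of that lookup pattern (the registry, key names, and handler below are invented for this sketch, not hivemind API):

import asyncio

async def demo():
    # Invented stand-in for self._running_groups: one Future per (group_id, phase).
    running_groups = {}
    group_id = b"example-group"
    loop = asyncio.get_running_loop()
    for phase in (b".phase1", b".phase2"):
        running_groups[group_id + phase] = loop.create_future()

    async def incoming_connection(key):
        # A peer connecting for a given phase awaits the future for that key
        # until the local code sets the corresponding runner as its result.
        return await running_groups[key]

    waiter = asyncio.create_task(incoming_connection(group_id + b".phase2"))
    running_groups[group_id + b".phase2"].set_result("allreduce2 runner")
    print(await waiter)  # "allreduce2 runner"

asyncio.run(demo())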
@@ -122,18 +126,22 @@ class PowerEFGradientAverager(GradientAverager):
             )
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                compressed_tensors = [lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
+                compressed_tensors = [
+                    lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients
+                ]
 
                 cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
                 ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
                 for p, q, rest in zip(ps, self._qs, cs):
                     torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
-                first_all_reduced = ps + [rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients]
+                first_all_reduced = ps + [
+                    rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients
+                ]
                 allreduce1 = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + b'.phase1',
+                    group_id=group_info.group_id + b".phase1",
                     tensors=first_all_reduced,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -141,7 +149,7 @@ class PowerEFGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + b'.phase1'].set_result(allreduce1)
+                self._running_groups[group_info.group_id + b".phase1"].set_result(allreduce1)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
@@ -163,7 +171,7 @@ class PowerEFGradientAverager(GradientAverager):
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + b'.phase2',
+                    group_id=group_info.group_id + b".phase2",
                     tensors=self._qs,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -171,7 +179,7 @@ class PowerEFGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + b'.phase2'].set_result(allreduce2)
+                self._running_groups[group_info.group_id + b".phase2"].set_result(allreduce2)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
@@ -212,8 +220,8 @@ class PowerEFGradientAverager(GradientAverager):
 def orthogonalize(matrix, eps=torch.tensor(1e-8)):
     n, m = matrix.shape
     for i in range(m):
-        col = matrix[:, i: i + 1]
+        col = matrix[:, i : i + 1]
         col /= torch.sqrt(torch.sum(col ** 2)) + eps
         if i + 1 < m:
-            rest = matrix[:, i + 1:]
+            rest = matrix[:, i + 1 :]
             rest -= torch.sum(col * rest, dim=0) * col
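
For context, the two phased all-reduce rounds above amount to one power-iteration step of a rank-r gradient approximation in the PowerSGD style. A single-process sketch of that step, reusing the orthogonalize helper from the diff; the all-reduce exchanges are replaced by purely local tensors and the shapes are invented:

import torch

def orthogonalize(matrix, eps=torch.tensor(1e-8)):
    # Same Gram-Schmidt pass as in the diff above.
    n, m = matrix.shape
    for i in range(m):
        col = matrix[:, i : i + 1]
        col /= torch.sqrt(torch.sum(col ** 2)) + eps
        if i + 1 < m:
            rest = matrix[:, i + 1 :]
            rest -= torch.sum(col * rest, dim=0) * col

torch.manual_seed(0)
rank = 4
grad = torch.randn(256, 128)        # stand-in for one local gradient residual
q = torch.rand(grad.size(1), rank)  # shared right factor, like an entry of self._qs

# Phase 1: project the residual onto Q; in the averager this P is what gets all-reduced.
p = grad @ q
orthogonalize(p)                    # orthonormal columns keep the power iteration stable

# Phase 2: refresh Q against the (averaged) P; this Q is all-reduced in the second round.
q = grad.t() @ p

approx = p @ q.t()                  # the rank-r update every peer reconstructs
print(torch.linalg.matrix_rank(approx))  # tensor(4)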