Artem Chumachenko 3 years ago
parent
commit
25e7e1216e

+ 0 - 77
examples/example.py

@@ -1,77 +0,0 @@
-import time
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchvision import datasets, transforms
-from tqdm.auto import tqdm
-
-import hivemind
-
-
-class SmallCNN(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        self.features = nn.Sequential(
-            nn.Conv2d(1, 16, (9, 9)),
-            nn.ReLU(),
-            nn.Conv2d(16, 16, (9, 9)),
-            nn.ReLU(),
-            nn.MaxPool2d(2)
-        )
-
-        self.cls = nn.Sequential(
-            nn.Linear(16 * 6 * 6, 400),
-            nn.ReLU(),
-            nn.Linear(400, 10)
-        )
-
-    def forward(self, x):
-        feature = self.features(x)
-        return self.cls(feature.view(x.size(0), -1))
-
-
-if __name__ == "__main__":
-    # Create dataset and model, same as in the basic tutorial
-    # For this basic tutorial, we download only the training set
-    transform = transforms.Compose([transforms.ToTensor()])
-
-    trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
-
-    model = SmallCNN()
-    opt = torch.optim.Adam(model.parameters(), lr=0.001)
-
-    # Create DHT: a decentralized key-value storage shared between peers
-    dht = hivemind.DHT(start=True, initial_peers=["/ip4/127.0.0.1/tcp/36805/p2p/Qmc7nJt6Pc3Eii4X1ZqtkxbiRWvf97nNfuD4CJpAep5THU"])
-    print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
-
-    # Set up a decentralized optimizer that will average with peers in background
-    opt = hivemind.Optimizer(
-        dht=dht,                  # use a DHT that is connected with other peers
-        run_id='my_cifar_run',    # unique identifier of this collaborative run
-        batch_size_per_step=16,   # each call to opt.step adds this many samples towards the next epoch
-        target_batch_size=1000,  # after peers collectively process this many samples, average weights and begin the next epoch 
-        optimizer=opt,            # wrap the SGD optimizer defined above
-        use_local_updates=False,  # perform optimizer steps with averaged gradients
-        matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
-        averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
-    verbose=True,             # print logs incessantly
-        grad_rank_averager="power_sgd",
-        grad_averager_opts={"averager_rank": 1}
-    )
-    opt.load_state_from_peers()
-
-    # Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
-    with tqdm() as progressbar:
-        while True:
-            for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=16):
-                time.sleep(0.1)
-                opt.zero_grad()
-                loss = F.cross_entropy(model(x_batch), y_batch)
-                loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-                opt.step()
-
-                progressbar.desc = f"loss = {loss.item():.3f}"
-                progressbar.update()
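
The removed example wrapped a local torch.optim.Adam in hivemind.Optimizer with PowerSGD rank-1 gradient averaging. For orientation: with batch_size_per_step=16 and target_batch_size=1000, the peers collectively trigger averaging and advance to the next epoch after roughly 1000 / 16 ≈ 63 optimizer steps spread across all participants. A second peer joins the same run by pointing its DHT at the multiaddresses printed by the first peer. A minimal sketch reusing the example's settings (the multiaddress is a placeholder, and grad_rank_averager / grad_averager_opts are keywords specific to this branch):

    import hivemind
    import torch

    # Join an existing run: pass the multiaddresses printed by the first peer.
    dht = hivemind.DHT(initial_peers=["<multiaddress printed by the first peer>"], start=True)

    model = SmallCNN()  # same model class as in the removed example
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="my_cifar_run",                     # must match the run_id of the existing peers
        batch_size_per_step=16,
        target_batch_size=1000,
        optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
        use_local_updates=False,
        grad_rank_averager="power_sgd",            # branch-specific keyword, as in the example above
        grad_averager_opts={"averager_rank": 1},
    )
    opt.load_state_from_peers()                    # catch up with the collaboration before training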

+ 8 - 4
hivemind/optim/grad_averager.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import Iterable, Iterator, Optional
+from typing import Iterable, Iterator, Optional, Sequence
 
 import torch
 
@@ -75,6 +75,7 @@ class GradientAverager(DecentralizedAverager):
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        grad_extra_tensors: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -95,9 +96,12 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
 
         with torch.no_grad():
-            averaged_grads = tuple(
-                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
-            )
+            if grad_extra_tensors:
+                averaged_grads = grad_extra_tensors
+            else:
+                averaged_grads = tuple(
+                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+                )
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
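
The new grad_extra_tensors argument lets a caller hand the averager pre-allocated buffers to use as its averaged tensors, instead of the default path that clones every parameter's .grad into fresh shared-memory CPU tensors. A minimal sketch of how a caller could allocate such buffers; the helper name is illustrative, not part of the library:

    import torch

    def make_shared_grad_buffers(parameters):
        # One zero-initialized CPU tensor per parameter, placed in shared memory so the
        # background averager process operates on the same storage as the training process.
        buffers = tuple(torch.zeros_like(p, device="cpu") for p in parameters)
        for buf in buffers:
            buf.share_memory_()
        return buffers

    # Usage sketch against the constructor shown above (other required arguments omitted):
    # averager = GradientAverager(model.parameters(), dht=dht, prefix="my_run_grad_averager",
    #                             grad_extra_tensors=make_shared_grad_buffers(model.parameters()))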

+ 17 - 3
hivemind/optim/optimizer.py

@@ -12,7 +12,7 @@ from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager
-from hivemind.optim.experimental.power_ef_averager import PowerEFGradientAverager
+from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
@@ -194,7 +194,7 @@ class Optimizer(torch.optim.Optimizer):
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
-        grad_averager_opts: Optional[dict] = None,
+        grad_averager_opts: Optional[dict] = dict(),
         tracker_opts: Optional[dict] = None,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
@@ -244,6 +244,17 @@ class Optimizer(torch.optim.Optimizer):
         self.tracker = self._make_progress_tracker(
             target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
         )
+        if grad_rank_averager == "power_ef" and not use_local_updates:
+            assert len(extra_tensors) == 0
+            grad_extra_tensors = [
+                torch.zeros_like(param, device="cpu")
+                for param_group in optimizer.param_groups for param in param_group["params"]
+            ]
+            for tensor in grad_extra_tensors:
+                if tensor is not None:
+                    tensor.share_memory_()
+            grad_averager_opts["grad_extra_tensors"] = grad_extra_tensors
+            extra_tensors = [e for e in extra_tensors] + [eg for eg in grad_extra_tensors]
         self.state_averager = self._make_state_averager(
             optimizer=optimizer,
             params=params,
@@ -258,7 +269,10 @@ class Optimizer(torch.optim.Optimizer):
         )
         if not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, grad_rank_averager=grad_rank_averager, compression=grad_compression, **grad_averager_opts or {}
+                reuse_grad_buffers=reuse_grad_buffers,
+                grad_rank_averager=grad_rank_averager,
+                compression=grad_compression,
+                **grad_averager_opts or {}
             )
         else:
             self.grad_averager = None
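
With this change, selecting the PowerEF averager makes hivemind.Optimizer allocate one shared CPU buffer per parameter, pass those buffers to the gradient averager via grad_extra_tensors, and append them to extra_tensors so they travel with the averaged state. A hedged usage sketch on this branch; the keyword names follow the hunks above, and note that the branch asserts extra_tensors is empty and requires use_local_updates=False:

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="my_run",
        optimizer=base_optimizer,                  # a regular torch optimizer, e.g. torch.optim.Adam
        batch_size_per_step=16,
        target_batch_size=1000,
        use_local_updates=False,                   # the power_ef branch above only runs in this mode
        grad_rank_averager="power_ef",             # selects PowerEFGradientAverager
        grad_averager_opts={"averager_rank": 1},   # rank of the low-rank gradient approximation
    )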

+ 23 - 38
hivemind/optim/power_ef_averager.py

@@ -4,6 +4,7 @@ import faulthandler
 import math
 import torch
 import multiprocessing as mp
+import numpy as np
 
 from typing import Any, Iterable, Optional, Sequence
 
@@ -56,20 +57,23 @@ class PowerEFGradientAverager(GradientAverager):
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        min_comprasion_ratio: float = 0.5,
+        grad_extra_tensors: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         self.rank = averager_rank
         self.parameters = tuple(parameters)
-        self._uncompressed_gradients = set(i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1)
-        self._gs = list(
-            torch.zeros_like(grad, device="cpu")
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
+        self._uncompressed_gradients = set(
+            i for i, grad in enumerate(self._grads_from_parameters())
+            if len(tuple(grad.size())) == 1 or 
+            (self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio)
         )
+        self._gradient_rests = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
         self._qs = list(
             torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
         )
-        for tensor in (self._qs + self._gs):
+        for tensor in (self._qs + self._gradient_rests):
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
@@ -82,6 +86,7 @@ class PowerEFGradientAverager(GradientAverager):
             accumulate_grads_on=accumulate_grads_on,
             client_mode=client_mode,
             warn=warn,
+            grad_extra_tensors=grad_extra_tensors,
             **kwargs
         )
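
The reworked _uncompressed_gradients set skips compression not only for one-dimensional gradients but also for tensors where a low-rank factorization would barely save anything. For a gradient reshaped to an n × m matrix, rank-r factors store r(n + m) values instead of nm, so the condition above keeps a gradient uncompressed whenever

    \frac{r\,(n + m)}{n\,m} > 1 - \mathrm{min\_comprasion\_ratio}.

With the default min_comprasion_ratio = 0.5 and rank 1, a 512 × 512 gradient gives 1024 / 262144 ≈ 0.004, far below 0.5, so it is compressed; small or skinny tensors whose factors are nearly as large as the tensor itself stay uncompressed and are averaged directly.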
 
@@ -104,16 +109,11 @@ class PowerEFGradientAverager(GradientAverager):
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
-        async def _dump_later():
-            await asyncio.sleep(15.0)
-            print([*map(asyncio.Task.print_stack, asyncio.Task.all_tasks())])
-        # task = asyncio.create_task(_dump_later())
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
-            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
             download_bandwidths = [
                 thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
             ]
@@ -123,14 +123,12 @@ class PowerEFGradientAverager(GradientAverager):
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 compressed_tensors = [lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
-                cs = [torch.zeros_like(grad, device="cpu") for grad in compressed_tensors]
-                for c, g, cg in zip(cs, self._gs, compressed_tensors):
-                    torch.sub(cg, g, out=c)
 
+                cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
                 ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
-                for p, q, c in zip(ps, self._qs, cs):
-                    torch.matmul(c.reshape(-1, q.size(0)), q, out=p)
-                first_all_reduced = ps + [lt for idx, lt in enumerate(local_tensors) if idx in self._uncompressed_gradients]
+                for p, q, rest in zip(ps, self._qs, cs):
+                    torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
+                first_all_reduced = ps + [rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients]
                 allreduce1 = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -152,7 +150,7 @@ class PowerEFGradientAverager(GradientAverager):
                 else:
                     async for _ in allreduce1:  # trigger all-reduce by iterating
                         raise ValueError("aux peers should not receive averaged tensors")
-                
+
                 # orth ps
                 for p in ps:
                     orthogonalize(p)
@@ -190,8 +188,8 @@ class PowerEFGradientAverager(GradientAverager):
                     new_c = torch.matmul(p, q.t())
                     c.copy_(new_c.reshape(c.size()))
 
-                for c, g in zip(cs, self._gs):
-                    torch.add(g, c, out=g)
+                for rest, lt in zip(self._gradient_rests, local_tensors):
+                    torch.add(lt, rest, out=lt)
 
                 return allreduce1.gathered
         except BaseException as e:
@@ -200,27 +198,14 @@ class PowerEFGradientAverager(GradientAverager):
         finally:
             pass
 
-    @contextlib.contextmanager
     @torch.no_grad()
-    def use_averaged_gradients(self):
-        self._new_averaged_grads = False
+    def load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
-            compressed_tensors = [lt for idx, lt in enumerate(averaged_grads) if idx not in self._uncompressed_gradients]
-            old_averaged = [torch.zeros_like(lt) for lt in compressed_tensors]
-            for g, cg, oag in zip(self._gs, compressed_tensors, old_averaged):
-                oag.copy_(cg)
-                cg.copy_(g)
-            try:
-                assert len(averaged_grads) == len(self.parameters)
-                old_grads = [param.grad for param in self.parameters]
-                for param, new_grad in zip(self.parameters, averaged_grads):
-                    param.grad.copy_(new_grad)
-                yield
-            finally:
-                for param, old_grad in zip(self.parameters, old_grads):
-                    param.grad.copy_(old_grad)
-            for cg, oag in zip(compressed_tensors, old_averaged):
-                cg.copy_(oag)
+            for grad_acc, averaged_grad, rest in zip(self._grad_accumulators(), averaged_grads, self._gradient_rests):
+                torch.sub(grad_acc * grad_scale, averaged_grad, out=rest)
 
 
 @torch.jit.script
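
For reference, the round that PowerEFGradientAverager runs after this change, for one compressed gradient reshaped to an n × m matrix with rank r, a persistent factor Q of shape m × r, locally accumulated gradients G_acc, and k accumulation steps, can be summarized as follows. The Q refresh happens in a second all-reduce that is only partly visible in the hunks above, so treat this as a sketch rather than a full specification:

    R \leftarrow \tfrac{1}{k} G_{\mathrm{acc}} - \bar{G}    % load_accumulators_into_averager_: residual of local vs. averaged grads
    P \leftarrow R Q                                        % first all-reduce averages P (and the uncompressed residuals) across peers
    P \leftarrow \mathrm{orthogonalize}(P)
    Q \leftarrow R^{\top} P                                 % refreshed and averaged in the second all-reduce (not fully shown)
    \hat{R} \leftarrow P Q^{\top}, \qquad \bar{G} \leftarrow \bar{G} + \hat{R}

No separate error buffer is kept: the next round recomputes the residual against the updated averaged gradients, so whatever the low-rank step missed re-enters the residual automatically, which acts as the error feedback in PowerEF.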