Artem Chumachenko 3 years ago
Parent
Commit
10cce6ef23

+ 25 - 38
examples/playgroud_example.py

@@ -1,6 +1,7 @@
 import hivemind
 from hivemind.optim.experimental.grad_averager import GradientAverager
 from hivemind.optim.experimental.power_ef_averager import PowerEFGradientAverager
+from hivemind.optim.experimental.power_sgd_averager import PowerSGDGradientAverager
 
 import faulthandler
 import torch
@@ -12,6 +13,7 @@ from torchvision.datasets import MNIST
 import multiprocessing as mp
 import threading
 import os
+import random
 import time
 
 
@@ -26,9 +28,6 @@ class Peer(threading.Thread):
         for param in self.model.parameters():
             param.grad = torch.zeros_like(param).share_memory_()
 
-        self.averager = PowerEFGradientAverager(
-            self.model.parameters(), 1, dht=self.dht, target_group_size=4, prefix='my_mega_exp', start=True,
-        )
         if start:
             self.start()
 
@@ -42,39 +41,33 @@ class Peer(threading.Thread):
 
         def data():
             while True:
-                train_dataloader = torch.utils.data.DataLoader(train_data, num_workers=0, batch_size=1024, shuffle=True)
+                train_dataloader = torch.utils.data.DataLoader(train_data, num_workers=0, batch_size=64, shuffle=True)
                 for batch in train_dataloader:
                     yield batch
         
-        opt = torch.optim.Adam(self.model.parameters(), lr=0.001)
-        
-        next_step_time = hivemind.get_dht_time() + 5
-        next_step_control = None
+        opt = hivemind.Optimizer(
+            dht=self.dht,
+            prefix="my_super_run",
+            params=self.model.parameters(),
+            optimizer=torch.optim.SGD,
+            lr=0.1,
+            train_batch_size=256,
+            batch_size=64
+        )
+        opt.load_state_from_peers()
+
         for i, (xb, yb) in enumerate(data()):
             logits = self.model(xb)
             loss = F.cross_entropy(logits, yb)
 
             loss.backward()
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
-            if next_step_control is None and (next_step_time - hivemind.get_dht_time() <= 1):
-                next_step_control = self.averager.schedule_step(scheduled_time=next_step_time)
             
-            self.averager.accumulate_grads_(batch_size=1024)
-
-            if hivemind.get_dht_time() >= next_step_time:
-                self.averager.step(control=next_step_control)
-                next_step_control.result()
-                with self.averager.use_averaged_gradients():
-                    with torch.no_grad():
-                        param = next(iter(self.model.parameters()))
-                        grad = param.grad.detach().cpu().norm().item()
-                        print_param = param.flatten()[-3:].detach().cpu().numpy()
-                        print(i, self.dht.peer_id.pretty()[-3:],f"{loss.item():.3f}", f"{hivemind.get_dht_time():.3f}", print_param, grad)
-                    opt.step()
-                self.averager.reset_accumulated_grads_()
-                next_step_time = hivemind.get_dht_time() + 5
-                next_step_control = None
-            if i > 10000: break
+
+            opt.step()
+            opt.zero_grad()
+            if i > 100000: break
 
 
 class SmallCNN(nn.Module):
@@ -82,19 +75,15 @@ class SmallCNN(nn.Module):
         super().__init__()
 
         self.features = nn.Sequential(
-            nn.Conv2d(1, 4, (5, 5)),
-            nn.ReLU(),
-            nn.Conv2d(4, 16, (5, 5)),
-            nn.ReLU(),
-            nn.Conv2d(16, 64, (5, 5)),
+            nn.Conv2d(1, 16, (9, 9)),
             nn.ReLU(),
-            nn.Conv2d(64, 64, (5, 5)),
+            nn.Conv2d(16, 16, (9, 9)),
             nn.ReLU(),
             nn.MaxPool2d(2)
         )
 
         self.cls = nn.Sequential(
-            nn.Linear(64 * 6 * 6, 400),
+            nn.Linear(16 * 6 * 6, 400),
             nn.ReLU(),
             nn.Linear(400, 10)
         )
@@ -108,12 +97,10 @@ if __name__ == "__main__":
     dht_root = hivemind.DHT(start=True)
 
     peers = [
-        Peer(0, start=False), Peer(1, start=False),
-        Peer(2, start=False), Peer(3, start=False)
+        Peer(i, start=False) for i in range(4)
     ]
-    peers[1].model.load_state_dict(peers[0].model.state_dict())
-    peers[2].model.load_state_dict(peers[0].model.state_dict())
-    peers[3].model.load_state_dict(peers[0].model.state_dict())
+    for i in range(1, 4):
+        peers[i].model.load_state_dict(peers[0].model.state_dict())
 
     for peer in peers:
         peer.start()
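
The manual averaging loop is gone: hivemind.Optimizer now owns gradient accumulation, peer matchmaking, and the underlying SGD update. Below is a standalone sketch of the new training loop, not part of the commit; the keyword names (prefix, train_batch_size, batch_size) are copied from the diff above and may differ in other hivemind versions, and the model and data are dummies standing in for SmallCNN and the MNIST loader.

    import hivemind
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    dht = hivemind.DHT(start=True)
    model = nn.Linear(28 * 28, 10)  # stand-in for SmallCNN

    opt = hivemind.Optimizer(
        dht=dht,
        prefix="my_super_run",
        params=model.parameters(),
        optimizer=torch.optim.SGD,
        lr=0.1,
        train_batch_size=256,  # presumably the global batch accumulated across peers
        batch_size=64,         # presumably the local batch this peer contributes per step
    )
    opt.load_state_from_peers()

    for step in range(100):
        xb, yb = torch.randn(64, 28 * 28), torch.randint(0, 10, (64,))  # dummy batch
        loss = F.cross_entropy(model(xb), yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()       # gradient averaging and the SGD update happen here
        opt.zero_grad()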

+ 1 - 2
hivemind/optim/experimental/power_ef_averager.py

@@ -191,7 +191,7 @@ class PowerEFGradientAverager(GradientAverager):
                     c.copy_(new_c.reshape(c.size()))
 
                 for c, g in zip(cs, self._gs):
-                    torch.add(g, c, out=g)
+                    torch.add(g, c * 0.9, out=g)
 
                 return allreduce1.gathered
         except BaseException as e:
@@ -199,7 +199,6 @@ class PowerEFGradientAverager(GradientAverager):
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
         finally:
             pass
-            # task.cancel()
 
     @contextlib.contextmanager
     @torch.no_grad()
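
The only functional change here is that the compressed correction c is now damped by a factor of 0.9 before being added back into the gradient buffer g. A minimal equivalent sketch, assuming 0.9 is intended as a fixed damping constant: torch.Tensor.add_ with alpha performs the same update in place without allocating the temporary produced by c * 0.9.

    # g += 0.9 * c, fused and in place (no temporary for c * 0.9)
    for c, g in zip(cs, self._gs):
        g.add_(c, alpha=0.9)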

+ 194 - 0
hivemind/optim/experimental/power_sgd_averager.py

@@ -0,0 +1,194 @@
+import asyncio
+import contextlib
+import faulthandler
+import math
+import torch
+import multiprocessing as mp
+
+from typing import Any, Iterable, Optional, Sequence
+
+import hivemind
+from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
+from hivemind.dht import DHT, DHTID
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.proto import averaging_pb2
+from hivemind.utils import MPFuture, TensorDescriptor, get_logger
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
+from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
+
+from .grad_averager import GradientAverager
+from .power_ef_averager import PowerEFGradientAverager, orthogonalize
+
+GatheredData = Any
+logger = get_logger(__name__)
+
+
+class PowerSGDGradientAverager(PowerEFGradientAverager):
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        rank: int,
+        *,
+        dht: hivemind.DHT,
+        prefix: str,
+        local_updates: bool = False,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        **kwargs,
+    ):
+        self.rank = rank
+        self.parameters = tuple(parameters)
+        self._local_updates = local_updates
+        self._uncompressed_gradients = set(i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1)
+        self._ms = list(
+            torch.zeros_like(grad, device=accumulate_grads_on)
+            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
+        )
+        self._gs = list(
+            torch.zeros_like(grad, device=accumulate_grads_on)
+            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
+        )
+        self._qs = list(
+            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device=accumulate_grads_on)
+            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
+        )
+        for tensor in (self._qs + self._gs):
+            if tensor is not None:
+                assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
+                tensor.share_memory_()
+
+        super().__init__(
+            self.parameters,
+            rank=rank,
+            dht=dht,
+            prefix=prefix,
+            reuse_grad_buffers=reuse_grad_buffers,
+            accumulate_grads_on=accumulate_grads_on,
+            client_mode=client_mode,
+            warn=warn,
+            **kwargs
+        )
+
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+        async def _dump_later():
+            await asyncio.sleep(15.0)
+            print([*map(asyncio.Task.print_stack, asyncio.Task.all_tasks())])
+        # task = asyncio.create_task(_dump_later())
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+
+            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
+                compressed_tensors = [lt for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
+                for m, cg in zip(self._ms, compressed_tensors):
+                    torch.sub(cg, m, out=m)
+
+                ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
+                local_ps = [p.detach() for p in ps]
+                for p, local_p, q, m in zip(ps, local_ps, self._qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
+                    local_p.copy_(p)
+                first_all_reduced = ps + [lt for idx, lt in enumerate(local_tensors) if idx in self._uncompressed_gradients]
+                allreduce1 = AllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id + b'.phase1',
+                    tensors=first_all_reduced,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    **kwargs,
+                )
+                self._running_groups[group_info.group_id + b'.phase1'].set_result(allreduce1)
+
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                else:
+                    async for _ in allreduce1:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
+                
+                # orth ps
+                for p in ps + local_ps:
+                    orthogonalize(p)
+
+                # compute qs
+                local_qs = [q.detach() for q in self._qs]
+                for p, local_p, q, local_q, m in zip(ps, local_ps, self._qs, local_qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), local_p, out=local_q)
+
+                allreduce2 = AllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id + b'.phase2',
+                    tensors=self._qs,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    **kwargs,
+                )
+                self._running_groups[group_info.group_id + b'.phase2'].set_result(allreduce2)
+
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+                        self._state_updated.set()
+                else:
+                    async for _ in allreduce2:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
+
+                # recompute grads
+                for p, local_p, q, local_q, m, g in zip(ps, local_ps, self._qs, local_qs, self._ms, self._gs):
+                    new_g = torch.matmul(p, q.t()).reshape(g.size())
+                    g.copy_(new_g)
+                    sub_g = torch.matmul(local_p, local_q.t()).reshape(g.size()) if self._local_updates else new_g
+                    torch.sub(m, sub_g, out=m)
+
+                return allreduce1.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+        finally:
+            pass
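
The new PowerSGDGradientAverager replaces each full-gradient all-reduce with two low-rank phases: a buffer M built from the local gradient and the previous residual is projected to P = M @ Q and all-reduced together with the 1-D tensors that stay uncompressed, P is orthogonalized, Q = M^T @ P is all-reduced in a second phase, the compressed gradient is reconstructed as P @ Q^T into self._gs, and the leftover M - P @ Q^T is kept in self._ms for the next round. Below is a minimal usage sketch, not part of the commit, mirroring how the playground example previously drove PowerEFGradientAverager; target_group_size and start are forwarded to the base GradientAverager through **kwargs, and in a real run each of the target_group_size peers would execute this loop.

    import hivemind
    import torch
    import torch.nn as nn
    from hivemind.optim.experimental.power_sgd_averager import PowerSGDGradientAverager

    dht = hivemind.DHT(start=True)
    model = nn.Linear(128, 10)
    for param in model.parameters():
        param.grad = torch.zeros_like(param)

    averager = PowerSGDGradientAverager(
        model.parameters(),
        rank=1,                  # rank of the P @ Q^T approximation
        dht=dht,
        prefix="my_mega_exp",
        target_group_size=4,
        start=True,
    )
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    # One collaborative step, following the old playground loop:
    averager.accumulate_grads_(batch_size=64)
    averager.step()                        # waits for a group and runs both all-reduce phases
    with averager.use_averaged_gradients():
        opt.step()                         # apply the group-averaged, low-rank-compressed gradients
    averager.reset_accumulated_grads_()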