Forráskód Böngészése

Add gradient buffers to CollaborativeOptimizer (#220)

- Add an option to accumulate gradients locally (default = True, can be disabled to save some memory, not critical in tutorials)
- Fix a bug in CollaborativeOptimizer that caused it to implicitly scale all updates by self.local_steps_accumulated
- Fix some typos in description
justheuristic 4 éve
szülő
commit
2359906253
1 módosított fájl, 67 hozzáadás és 11 törlés
  1. 67 11
      hivemind/client/optim/collaborative.py

+ 67 - 11
hivemind/client/optim/collaborative.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 import warnings
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
-from typing import Optional, Type
+from typing import Optional, Iterator
 import logging
 
 import torch
@@ -11,7 +11,7 @@ import numpy as np
 from hivemind.dht import DHT
 from hivemind.client.optim.base import DecentralizedOptimizerBase
 from hivemind.client.averaging.training import TrainingAverager
-from hivemind.utils import get_logger, get_dht_time, run_in_background, ValueWithExpiration
+from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 from hivemind.client.optim.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
@@ -47,7 +47,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
 
-    - calling .step will periodially zero-out gradients w.r.t. model parameters after each step
+    - calling .step will periodically zero-out gradients w.r.t. model parameters after each step
     - it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
 
     :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
@@ -55,7 +55,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
     :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param target_group_size: maximum group size for DecentralizedAverager's all-reduce
     :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
     :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
     :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
@@ -69,6 +68,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+     device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+     the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
+    :param kwargs: additional parameters forwarded to DecentralizedAverager
     :note: if you are using CollaborativeOptimizer with a lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
@@ -78,14 +83,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
                  expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
                  metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, verbose: bool = False,
-                 **kwargs):
+                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None, **kwargs):
         super().__init__(opt, dht)
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
         self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
         self.min_refresh_period, self.max_refresh_period, self.default_refresh_period =\
             min_refresh_period, max_refresh_period, default_refresh_period
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
         self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
@@ -134,9 +142,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         :param batch_size: optional override for batch_size_per_step from init
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        if batch_size is not None and self.batch_size_per_step is None:
-            raise ValueError("Please either set batch_size_per_step parameter at init or provide batch_size in .step")
-        batch_size = self.batch_size_per_step if batch_size is None else batch_size
+        if self.batch_size_per_step is None:
+            if batch_size is None:
+                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
+            self.batch_size_per_step = batch_size
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
         if not self.is_synchronized:
             self.load_state_from_peers()
@@ -146,6 +157,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
                            f"but metadata expired in {self.metadata_expiration} s.")
 
+        self.accumulate_grads_(batch_size)
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
@@ -164,6 +176,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+            self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
+
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
@@ -176,6 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.opt.zero_grad()
+            self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
             self.collaboration_state_updated.set()
@@ -184,6 +200,46 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             logger.log(self.status_loglevel, f"Optimizer step: done!")
             return output
 
+    def _grad_buffers(self) -> Iterator[torch.Tensor]:
+        """ pytorch-internal gradient buffers """
+        for param_group in self.opt.param_groups:
+            for param in param_group['params']:
+                if param.grad is None:
+                    yield torch.zeros_like(param)
+                else:
+                    yield param.grad
+
+    @torch.no_grad()
+    def accumulated_grads(self) -> Iterator[torch.Tensor]:
+        """ local gradient accumulators """
+        if self.reuse_grad_buffers:
+            yield from self._grad_buffers()
+        elif self._grads is None:
+            with torch.no_grad():
+                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+        yield from self._grads
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """ add current gradients to grad accumulators (if any) """
+        if self.reuse_grad_buffers:
+            return  # user is responsible for accumulating gradients in .grad buffers
+        alpha = float(batch_size) / self.batch_size_per_step
+        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    @torch.no_grad()
+    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
+        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+            grad_buf[...] = grad_acc.to(grad_buf.device)
+            if scale_by is not None:
+                grad_buf.mul_(scale_by)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        for grad_buf in self._grad_buffers():
+            grad_buf.zero_()
+
     def report_training_progress(self):
         """ Periodically publish metadata and the current number of samples accumulated towards the next step """
         while self.is_alive():
@@ -235,17 +291,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             if not is_client:
                 global_optimizer_step = max(global_optimizer_step, opt_step)
 
-        total_samples_accumulated = estimated_curent_samples = total_samples_per_second = 0
+        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
 
         for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
             total_samples_per_second += samples_per_second
             if opt_step == global_optimizer_step:
                 total_samples_accumulated += samples_accumulated
-                estimated_curent_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
+                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
 
-        estimated_samples_remaining = self.target_batch_size - estimated_curent_samples
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
         estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
 
         expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))