
Add gradient buffers to CollaborativeOptimizer (#220)

- Add an option to accumulate gradients locally (default = True, can be disabled to save some memory, not critical in tutorials)
- Fix a bug in CollaborativeOptimizer that caused it to implicitly scale all updates by self.local_steps_accumulated
- Fix some typos in description
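A rough usage sketch of the new options (the toy model, the bare DHT setup, and the hyperparameter values below are illustrative assumptions, not taken from this commit):

import torch
from hivemind.dht import DHT
from hivemind.client.optim.collaborative import CollaborativeOptimizer

model = torch.nn.Linear(16, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # a large-batch optimizer (LAMB, LARS) is preferable in practice
dht = DHT(start=True)                              # peer discovery / listen settings omitted for brevity

collab_opt = CollaborativeOptimizer(
    opt=opt, dht=dht, prefix='my_run', target_batch_size=4096, batch_size_per_step=32,
    accumulate_grads_on=torch.device('cpu'),  # keep the local accumulators on CPU to save accelerator memory
    # reuse_grad_buffers=True,                # alternative: accumulate directly in .grad buffers,
)                                             #   but then never call zero_grad yourself

inputs, targets = torch.randn(32, 16), torch.randint(0, 2, (32,))
for _ in range(10):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    collab_opt.step()
    opt.zero_grad()  # safe with the default reuse_grad_buffers=False: gradients were already
                     # copied into the local accumulators inside .step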
justheuristic 4 years ago
commit 2359906253
1 changed file with 67 additions and 11 deletions

hivemind/client/optim/collaborative.py  (+67 -11)

@@ -2,7 +2,7 @@ from __future__ import annotations
 import warnings
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
-from typing import Optional, Type
+from typing import Optional, Iterator
 import logging

 import torch
@@ -11,7 +11,7 @@ import numpy as np
 from hivemind.dht import DHT
 from hivemind.client.optim.base import DecentralizedOptimizerBase
 from hivemind.client.averaging.training import TrainingAverager
-from hivemind.utils import get_logger, get_dht_time, run_in_background, ValueWithExpiration
+from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 from hivemind.client.optim.performance_ema import PerformanceEMA

 logger = get_logger(__name__)
@@ -47,7 +47,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

     :note: This optimizer behaves unlike regular pytorch optimizers in two ways:

-    - calling .step will periodially zero-out gradients w.r.t. model parameters after each step
+    - calling .step will periodically zero-out gradients w.r.t. model parameters after each step
     - it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples

     :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
@@ -55,7 +55,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
     :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param target_group_size: maximum group size for DecentralizedAverager's all-reduce
     :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
     :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
     :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
@@ -69,6 +68,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+     device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+     the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param kwargs: additional parameters forwarded to DecentralizedAverager
     :note: if you are using CollaborativeOptimizer with a lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
@@ -78,14 +83,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                  min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
                  expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
                  metadata_expiration: float = 30.0, averaging_timeout: Optional[float] = None, verbose: bool = False,
-                 **kwargs):
+                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None, **kwargs):
         super().__init__(opt, dht)
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
         self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
         self.min_refresh_period, self.max_refresh_period, self.default_refresh_period =\
             min_refresh_period, max_refresh_period, default_refresh_period
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
         self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
@@ -134,9 +142,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         :param batch_size: optional override for batch_size_per_step from init
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        if batch_size is not None and self.batch_size_per_step is None:
-            raise ValueError("Please either set batch_size_per_step parameter at init or provide batch_size in .step")
-        batch_size = self.batch_size_per_step if batch_size is None else batch_size
+        if self.batch_size_per_step is None:
+            if batch_size is None:
+                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
+            self.batch_size_per_step = batch_size
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step

         if not self.is_synchronized:
             self.load_state_from_peers()
@@ -146,6 +157,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
                            f"but metadata expired in {self.metadata_expiration} s.")
 
+        self.accumulate_grads_(batch_size)
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
@@ -164,6 +176,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return

         with self.performance_ema.pause(), self.lock_collaboration_state:
+            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+            self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
+
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
@@ -176,6 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

             self.opt.step()
             self.opt.zero_grad()
+            self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
             self.collaboration_state_updated.set()
@@ -184,6 +200,46 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             logger.log(self.status_loglevel, f"Optimizer step: done!")
             return output
 
+    def _grad_buffers(self) -> Iterator[torch.Tensor]:
+        """ pytorch-internal gradient buffers """
+        for param_group in self.opt.param_groups:
+            for param in param_group['params']:
+                if param.grad is None:
+                    yield torch.zeros_like(param)
+                else:
+                    yield param.grad
+
+    @torch.no_grad()
+    def accumulated_grads(self) -> Iterator[torch.Tensor]:
+        """ local gradient accumulators """
+        if self.reuse_grad_buffers:
+            yield from self._grad_buffers()
+        elif self._grads is None:
+            with torch.no_grad():
+                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+        yield from self._grads
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """ add current gradients to grad accumulators (if any) """
+        if self.reuse_grad_buffers:
+            return  # user is responsible for accumulating gradients in .grad buffers
+        alpha = float(batch_size) / self.batch_size_per_step
+        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    @torch.no_grad()
+    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
+        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+            grad_buf[...] = grad_acc.to(grad_buf.device)
+            if scale_by is not None:
+                grad_buf.mul_(scale_by)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        for grad_buf in self._grad_buffers():
+            grad_buf.zero_()
+
     def report_training_progress(self):
         """ Periodically publish metadata and the current number of samples accumulated towards the next step """
         while self.is_alive():
@@ -235,17 +291,17 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             if not is_client:
                 global_optimizer_step = max(global_optimizer_step, opt_step)
 
-        total_samples_accumulated = estimated_curent_samples = total_samples_per_second = 0
+        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

         for opt_step, samples_accumulated, samples_per_second, timestep, is_client in valid_peer_states:
             total_samples_per_second += samples_per_second
             if opt_step == global_optimizer_step:
                 total_samples_accumulated += samples_accumulated
-                estimated_curent_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
+                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
 
-        estimated_samples_remaining = self.target_batch_size - estimated_curent_samples
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
         estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second

         expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
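For intuition, a minimal standalone sketch (illustrative only, not part of the diff) of the accumulation arithmetic introduced above: each local step adds grad * (batch_size / batch_size_per_step) to the accumulator, and dividing by local_steps_accumulated right before the collaborative update recovers the average gradient instead of a sum that grows with the number of local steps, which is the implicit scaling bug this commit fixes.

import torch

batch_size_per_step = 32
per_step_grads = [torch.randn(4) for _ in range(5)]  # pretend per-step gradients of one parameter
per_step_batches = [32, 32, 16, 32, 32]              # one step happened to use a smaller batch

accumulator = torch.zeros(4)                         # plays the role of one tensor from accumulated_grads()
for grad, batch_size in zip(per_step_grads, per_step_batches):
    alpha = batch_size / batch_size_per_step
    accumulator.add_(grad, alpha=alpha)              # mirrors accumulate_grads_(batch_size)

local_steps_accumulated = len(per_step_grads)
applied = accumulator / local_steps_accumulated      # mirrors apply_accumulated_grads_(scale_by=1 / local_steps_accumulated)

# without the division, the applied gradient would be roughly local_steps_accumulated times too large
print(applied.norm().item(), accumulator.norm().item())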