
Support different AMP & buffer configurations in one experiment, fix minor bugs (#389)

New features:
- CollaborativeOptimizer can now combine fp16=True and reuse_grad_buffers=True with a special scaler
- CollaborativeOptimizer peers with reuse_grad_buffers=True and reuse_grad_buffers=False can now co-exist
- CollaborativeOptimizer peers with and without AMP can now co-exist


The new behavior of CollaborativeOptimizer with fp16 is as follows (see the usage sketch after this list):
* grad_scaler=None: regular fp32 behavior
* reuse_grad_buffers=False with GradScaler: works as usual, independently un-scales each tensor before accumulation, does not affect internal optimizer
* reuse_grad_buffers=True with GradScaler: calling scaler.step(opt) raises an error explaining that HivemindGradScaler is required
* reuse_grad_buffers=False with HivemindGradScaler: applies unscale/update only around global optimizer step
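
For reference, here is a minimal sketch of a training loop that combines fp16, reuse_grad_buffers=True and the new scaler. The DHT settings, model, hyperparameters, and the make_batches() loader are illustrative placeholders rather than part of this change:

```python
import torch
import torch.nn.functional as F
import hivemind
from hivemind.optim import CollaborativeOptimizer, HivemindGradScaler

dht = hivemind.DHT(start=True)          # placeholder DHT with default settings
model = torch.nn.Linear(512, 2).cuda()  # placeholder model
opt = CollaborativeOptimizer(
    opt=torch.optim.Adam(model.parameters(), lr=1e-3),
    dht=dht,
    prefix="my_experiment",             # placeholder experiment name
    target_batch_size=4096,
    batch_size_per_step=32,
    reuse_grad_buffers=True,            # accumulate gradients directly in param.grad
    verbose=True,
    start=True,
)
scaler = HivemindGradScaler()

for inputs, targets in make_batches():  # make_batches() is a hypothetical data loader
    with torch.cuda.amp.autocast():
        loss = F.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(opt)    # accumulates locally; runs the global optimizer step once the target batch size is reached
    scaler.update()     # only increases the scale right after a global optimizer step
    # note: no opt.zero_grad() here -- with reuse_grad_buffers=True, the optimizer
    # resets the .grad buffers itself after each global step
```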

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic committed 3 years ago
commit 1d862c9a5d
3 changed files with 131 additions and 22 deletions
  1. hivemind/optim/__init__.py (+1 −0)
  2. hivemind/optim/collaborative.py (+47 −22)
  3. hivemind/optim/grad_scaler.py (+83 −0)

+ 1 - 0
hivemind/optim/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 47 - 22
hivemind/optim/collaborative.py

@@ -14,8 +14,9 @@ from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import Endpoint, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -147,6 +148,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
+        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
+
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
@@ -197,6 +200,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 try:
                     self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
+                except KeyboardInterrupt:
+                    raise
                 except BaseException as e:
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
@@ -205,13 +210,16 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
-    def step(self, batch_size: Optional[int] = None, **kwargs):
+    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
         :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -227,6 +235,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
+            self.reset_accumulated_grads_()
+            self.should_report_progress.set()
+            return
+
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
                 f"Training step took {get_dht_time() - self.last_step_time}, "
@@ -251,6 +266,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.unscale_(self)
+
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
@@ -279,13 +298,21 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
                 )
 
-            self.opt.step()
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.step(self)
+            else:
+                self.opt.step()
+
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
 
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
@@ -344,38 +371,36 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """local gradient accumulators"""
         if self.reuse_grad_buffers:
             yield from self._grad_buffers()
-        elif self._grads is None:
-            with torch.no_grad():
-                self._grads = [
-                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
-                ]
+            return
+
+        if self._grads is None:
+            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
         yield from self._grads
 
     @torch.no_grad()
     def accumulate_grads_(self, batch_size: int):
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
-            return  # user is responsible for accumulating gradients in .grad buffers
-        alpha = float(batch_size) / self.batch_size_per_step
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+            # user is responsible for accumulating gradients in .grad buffers
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
+        else:
+            alpha = float(batch_size) / self.batch_size_per_step
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if self.reuse_grad_buffers:
-            return
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_buf[...] = grad_acc.to(grad_buf.device)
-            if scale_by is not None:
+        if not self.reuse_grad_buffers:
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
+        if scale_by is not None:
+            for grad_buf in self._grad_buffers():
                 grad_buf.mul_(scale_by)
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        if self.reuse_grad_buffers:
-            self.opt.zero_grad()
-        else:
-            for grad_buf in self.accumulated_grads():
-                grad_buf.zero_()
+        for grad_buf in self.accumulated_grads():
+            grad_buf.zero_()
 
     def report_training_progress(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next step"""

+ 83 - 0
hivemind/optim/grad_scaler.py

@@ -0,0 +1,83 @@
+import contextlib
+from typing import Dict, Optional
+
+import torch
+from torch.cuda.amp import GradScaler as TorchGradScaler
+from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
+from torch.optim import Optimizer
+
+from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class HivemindGradScaler(TorchGradScaler):
+    """
+    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
+    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
+    - limit increasing gradient scale to only immediately after global optimizer steps
+    - allow training with some or all master parameters in fp16
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._is_running_global_step = False
+        self._optimizer_states_to_reset = set()
+
+    @contextlib.contextmanager
+    def running_global_step(self):
+        was_running, self._is_running_global_step = self._is_running_global_step, True
+        try:
+            yield
+        finally:
+            self._is_running_global_step = was_running
+
+    def unscale_(self, optimizer: Optimizer) -> bool:
+        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        if self._is_running_global_step:
+            super().unscale_(optimizer.opt)
+            return True
+        else:
+            self._check_inf_per_device(optimizer.opt)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
+
+    def step(self, optimizer: Optimizer, *args, **kwargs) -> bool:
+        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        if self._is_running_global_step:
+            if self.are_grads_finite(optimizer):
+                super().step(optimizer.opt, *args, **kwargs)
+            else:
+                logger.warning("Skipping global step due to gradient over/underflow")
+            return True
+        else:
+            super().step(optimizer)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
+
+    def update(self, new_scale: Optional[float] = None) -> bool:
+        total_infs = 0
+        for optimizer_state in self._per_optimizer_states.values():
+            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+        if self._is_running_global_step or total_infs != 0:
+            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+            super().update(new_scale)
+            return True
+        else:
+            for opt_id in self._optimizer_states_to_reset:
+                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+            self._optimizer_states_to_reset.clear()
+            return False
+
+    def _unscale_grads_(
+        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+    ) -> Dict[torch.device, torch.Tensor]:
+        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
+        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
+        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
+
+    def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool:
+        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())