lessons from sahajbert-xl

justheuristic 3 years ago
commit ea127d9ee2
2 changed files with 122 additions and 24 deletions
  1. hivemind/optim/collaborative.py (+48 -24)
  2. hivemind/optim/grad_scaler.py (+74 -0)

hivemind/optim/collaborative.py (+48 -24)

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
 from threading import Event, Lock, Thread
-from typing import Dict, Iterator, Optional
+from typing import Dict, Iterator, Optional, Any, Callable
 
 import numpy as np
 import torch
@@ -15,7 +15,9 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import Endpoint, get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger
+
+from hivemind.optim.grad_scaler import HivemindGradScaler
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -97,6 +99,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
+    _step_supports_amp_scaling = True  # pytorch amp support
 
     def __init__(
         self,
@@ -197,6 +200,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 try:
                     self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
+                except KeyboardInterrupt:
+                    raise
                 except BaseException as e:
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
@@ -205,13 +210,15 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
-    def step(self, batch_size: Optional[int] = None, **kwargs):
+    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
 
         :param batch_size: optional override for batch_size_per_step from init
         :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        assert grad_scaler is None or isinstance(grad_scaler, HivemindGradScaler)
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -227,6 +234,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
+        if grad_scaler and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
+            self.reset_accumulated_grads_()
+            self.should_report_progress.set()
+            return
+
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
                 f"Training step took {get_dht_time() - self.last_step_time}, "
@@ -251,6 +265,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
+            if grad_scaler:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.unscale_(self)
+
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
@@ -279,13 +297,21 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
                     f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
                 )
                 )
 
 
-            self.opt.step()
+            if grad_scaler:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.step(self)
+            else:
+                self.opt.step()
+
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
+            if grad_scaler:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
 
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
@@ -344,38 +370,36 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """local gradient accumulators"""
         """local gradient accumulators"""
         if self.reuse_grad_buffers:
         if self.reuse_grad_buffers:
             yield from self._grad_buffers()
             yield from self._grad_buffers()
-        elif self._grads is None:
-            with torch.no_grad():
-                self._grads = [
-                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
-                ]
-        yield from self._grads
+            return
+        else:
+            if self._grads is None:
+                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+            yield from self._grads
 
     @torch.no_grad()
     def accumulate_grads_(self, batch_size: int):
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
-            return  # user is responsible for accumulating gradients in .grad buffers
-        alpha = float(batch_size) / self.batch_size_per_step
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+            # user is responsible for accumulating gradients in .grad buffers
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not implemented for reuse_grad_buffers"
+        else:
+            alpha = float(batch_size) / self.batch_size_per_step
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if self.reuse_grad_buffers:
-            return
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_buf[...] = grad_acc.to(grad_buf.device)
-            if scale_by is not None:
+        if not self.reuse_grad_buffers:
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
+        if scale_by is not None:
+            for grad_buf in self._grad_buffers():
                 grad_buf.mul_(scale_by)
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        if self.reuse_grad_buffers:
-            self.opt.zero_grad()
-        else:
-            for grad_buf in self.accumulated_grads():
-                grad_buf.zero_()
+        for grad_buf in self.accumulated_grads():
+            grad_buf.zero_()
 
     def report_training_progress(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next step"""

hivemind/optim/grad_scaler.py (+74 -0)

@@ -0,0 +1,74 @@
+import contextlib
+from typing import Dict
+
+import torch
+from hivemind import DecentralizedOptimizerBase, get_logger
+from torch.cuda.amp import GradScaler as TorchGradScaler
+from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
+from torch.optim import Optimizer
+
+
+logger = get_logger(__name__)
+
+
+class HivemindGradScaler(TorchGradScaler):
+    """A thin wrapper over GradScaler that supports hivemind-style training with CollaborativeOptimizer and others"""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._is_running_global_step = False
+        self._optimizer_states_to_reset = set()
+
+    @contextlib.contextmanager
+    def running_global_step(self):
+        was_running, self._is_running_global_step = self._is_running_global_step, True
+        try:
+            yield
+        finally:
+            self._is_running_global_step = was_running
+
+    def unscale_(self, optimizer, actually_unscale: bool = False):
+        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        if self._is_running_global_step:
+            super().unscale_(optimizer.opt)
+            return True
+        else:
+            self._check_inf_per_device(optimizer.opt)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
+
+    def step(self, optimizer, *args, **kwargs):
+        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        if self._is_running_global_step:
+            if self.are_grads_finite(optimizer):
+                super().step(optimizer.opt, *args, **kwargs)
+            else:
+                logger.warning("Skipping global step due to gradient over/underflow")
+            return True
+        else:
+            super().step(optimizer)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
+
+    def update(self, new_scale=None):
+        total_infs = 0
+        for optimizer_state in self._per_optimizer_states.values():
+            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+        if self._is_running_global_step or total_infs != 0:
+            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+            super().update(new_scale)
+            return True
+        else:
+            for opt_id in self._optimizer_states_to_reset:
+                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+            self._optimizer_states_to_reset.clear()
+            return False
+
+    def _unscale_grads_(
+        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+    ) -> Dict[torch.device, torch.Tensor]:
+        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
+
+    def are_grads_finite(self, optimizer: DecentralizedOptimizerBase):
+        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
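
In short, the new scaler has two modes: outside running_global_step() it only records inf/nan statistics and defers the real work, while inside the context it behaves like a regular GradScaler acting on the wrapped optimizer (optimizer.opt). A rough sketch of that contract, for illustration only (opt is assumed to be a CollaborativeOptimizer, scaler a HivemindGradScaler):

    # accumulation mode (no running_global_step() context):
    scaler.unscale_(opt)   # only runs _check_inf_per_device on opt.opt and returns False
    scaler.update()        # returns False and clears the recorded per-optimizer states,
                           # unless an overflow was seen and the scale must be reduced

    # global step mode, as used inside CollaborativeOptimizer.step:
    with scaler.running_global_step():
        scaler.unscale_(opt)   # actually unscales the gradients of opt.opt
        scaler.step(opt)       # steps opt.opt, or skips the step if grads overflowed
        scaler.update()        # grows/shrinks the loss scale like a regular GradScaler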