
Prepare GradScaler for hivemind.Optimizer (#413)

- Modified hivemind.GradScaler to make it compatible with hivemind.Optimizer (backwards-compatible)
- Changed TrainingStateAverager to be compatible with hivemind.GradScaler
- Made TrainingStateAverager.main_parameters and parameter_names public for use in optimizer

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic · 3 years ago
parent
commit 09e34f8366
3 changed files with 56 additions and 36 deletions
  1. hivemind/optim/collaborative.py (+2 -2)
  2. hivemind/optim/experimental/state_averager.py (+40 -24)
  3. hivemind/optim/grad_scaler.py (+14 -10)
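
The commit's practical effect is that a training loop can drive hivemind.GradScaler much like torch.cuda.amp.GradScaler, handing the scaler to the optimizer's step. A minimal sketch under stated assumptions: model, dataloader, compute_loss, and a pre-configured CollaborativeOptimizer named opt come from the surrounding script (illustrative names, not part of this diff), and the exact call order should be checked against the optimizer's documentation.

import torch
import hivemind

# Assumed from the surrounding training script (illustrative, not part of this diff):
#   model, dataloader, compute_loss, and `opt` - a configured hivemind.CollaborativeOptimizer
scaler = hivemind.GradScaler()  # thin wrapper over torch.cuda.amp.GradScaler

for batch in dataloader:
    with torch.cuda.amp.autocast():
        loss = compute_loss(model, batch)
    scaler.scale(loss).backward()
    scaler.unscale_(opt)          # bypassed until the collaboration-wide (global) step
    opt.step(grad_scaler=scaler)  # optimizer runs scaler.step on its inner torch optimizer
    scaler.update()               # likewise bypassed outside of a global step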

+ 2 - 2
hivemind/optim/collaborative.py

@@ -245,7 +245,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
@@ -310,7 +310,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self)
+                    assert grad_scaler.step(self.opt)
             else:
                 self.opt.step()
 

+ 40 - 24
hivemind/optim/experimental/state_averager.py

@@ -9,10 +9,10 @@ from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
 import torch
 
 import hivemind
-from hivemind import nested_compare
 from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import CompressionInfo, TensorRole
-from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 
@@ -100,7 +100,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.offload_optimizer = offload_optimizer
         self.custom_gradients = custom_gradients
 
-        self._main_parameters, self._parameter_names = main_parameters, parameter_names
+        self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
@@ -197,7 +197,7 @@ class TrainingStateAverager(DecentralizedAverager):
             initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
             logger.log(
                 self.status_loglevel,
-                "Initializing optimizer manually since it has no tensors in state dict"
+                "Initializing optimizer manually since it has no tensors in state dict. "
                 "To override this, please provide initialize_optimizer=False",
             )
 
@@ -257,12 +257,12 @@ class TrainingStateAverager(DecentralizedAverager):
     def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
         """Get CompressionInfo for each state tensor, accounting for its role and specification"""
         tensor_infos = []
-        for param, param_name in zip(self._main_parameters, self._parameter_names):
+        for param, param_name in zip(self.main_parameters, self.parameter_names):
             tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
         for stats_name in self.opt_keys_for_averaging:
             opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(opt_parameters) == len(self._parameter_names)
-            for param, param_name in zip(opt_parameters, self._parameter_names):
+            assert len(opt_parameters) == len(self.parameter_names)
+            for param, param_name in zip(opt_parameters, self.parameter_names):
                 tensor_infos.append(
                     CompressionInfo.from_tensor(
                         self.optimizer.state[param][stats_name],
@@ -284,7 +284,8 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        averaging_kwargs: Optional[Dict[str, Any]] = None,
+        grad_scaler: Optional[GradScaler] = None,
+        averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
         Perform one or several possible actions, depending on the specified keyword args.
@@ -298,9 +299,10 @@ class TrainingStateAverager(DecentralizedAverager):
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
-        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
@@ -312,8 +314,8 @@ class TrainingStateAverager(DecentralizedAverager):
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
-        if averaging_kwargs and not averaging_round:
-            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        if averaging_opts and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
         output = None
 
         if wait_for_delayed_update:
@@ -328,19 +330,17 @@ class TrainingStateAverager(DecentralizedAverager):
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
-                logger.log(self.status_loglevel, "Received results from background averaging round")
+                logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
                     self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
         if increment_epoch:
             self.local_epoch += 1
-            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
-            self._update_scheduler()
 
         if optimizer_step or zero_grad or averaging_round:
             assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
@@ -353,7 +353,8 @@ class TrainingStateAverager(DecentralizedAverager):
                 optimizer_step,
                 zero_grad,
                 averaging_round,
-                **averaging_kwargs or {},
+                grad_scaler,
+                **averaging_opts or {},
             )
 
             if (optimizer_step or zero_grad) and not delay_optimizer_step:
@@ -378,7 +379,9 @@ class TrainingStateAverager(DecentralizedAverager):
                     self.finished_optimizer_step.clear()
         return output
 
-    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+    def _do(
+        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+    ):
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
@@ -386,12 +389,23 @@ class TrainingStateAverager(DecentralizedAverager):
         try:
             if optimizer_step:
                 logger.log(self.status_loglevel, f"Running optimizer step")
-                self.optimizer.step()
+                if grad_scaler is None:
+                    self.optimizer.step()
+                else:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.step(self.optimizer)
+
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
+
+            self._update_scheduler()
+
             if zero_grad:
                 logger.log(self.status_loglevel, f"Running zero grad")
                 self.optimizer.zero_grad()
                 if self.offload_optimizer:
-                    for parameter in self._main_parameters:
+                    for parameter in self.main_parameters:
                         if parameter.grad is not None:
                             parameter.grad.zero_()
 
@@ -428,7 +442,7 @@ class TrainingStateAverager(DecentralizedAverager):
         """Copy local gradients into the gradient buffers of the offloaded optimizer"""
         assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
@@ -438,8 +452,10 @@ class TrainingStateAverager(DecentralizedAverager):
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
         with self.lock_averaged_tensors:
             offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
-            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
+            assert len(offloaded_parameters) == len(
+                self.main_parameters
+            ), "Optimizer parameters changed during training"
+            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
                 main_param.copy_(offloaded_param, non_blocking=True)
 
     @torch.no_grad()
@@ -471,7 +487,7 @@ class TrainingStateAverager(DecentralizedAverager):
             )
             parameter_infos = [
                 CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
-                for param, key in zip(optimized_parameters, self._parameter_names)
+                for param, key in zip(optimized_parameters, self.parameter_names)
             ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             extra_infos = [
@@ -496,7 +512,7 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
+        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
         num_parameters_and_extras = len(parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)
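
To illustrate the updated TrainingStateAverager.step signature above, here is a hedged sketch of how an optimizer wrapper might forward the scaler. The state_averager and scaler objects are assumed to be pre-constructed, and the averaging_opts keys are illustrative only.

# Hypothetical call site inside a hivemind optimizer wrapper (names are assumed);
# gradients are expected to be unscaled via scaler.unscale_(...) before this call.
state_averager.step(
    increment_epoch=True,
    optimizer_step=True,
    zero_grad=True,
    averaging_round=True,
    grad_scaler=scaler,                 # new in this commit, forwarded into _do()
    averaging_opts=dict(timeout=60.0),  # renamed from averaging_kwargs; forwarded to the averaging round
)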

+ 14 - 10
hivemind/optim/grad_scaler.py

@@ -4,7 +4,7 @@ from typing import Dict, Optional
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
 from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
-from torch.optim import Optimizer
+from torch.optim import Optimizer as TorchOptimizer
 
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils.logging import get_logger
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-class HivemindGradScaler(TorchGradScaler):
+class GradScaler(TorchGradScaler):
     """
     A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
     - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
@@ -33,7 +33,7 @@ class HivemindGradScaler(TorchGradScaler):
         finally:
             self._is_running_global_step = was_running
 
-    def unscale_(self, optimizer: Optimizer) -> bool:
+    def unscale_(self, optimizer: TorchOptimizer) -> bool:
         assert isinstance(optimizer, DecentralizedOptimizerBase)
         if self._is_running_global_step:
             super().unscale_(optimizer.opt)
@@ -43,11 +43,10 @@ class HivemindGradScaler(TorchGradScaler):
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
-    def step(self, optimizer: Optimizer, *args, **kwargs) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
+    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             if self.are_grads_finite(optimizer):
-                super().step(optimizer.opt, *args, **kwargs)
+                super().step(optimizer, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
             return True
@@ -72,12 +71,17 @@ class HivemindGradScaler(TorchGradScaler):
             return False
 
     def _unscale_grads_(
-        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
     ) -> Dict[torch.device, torch.Tensor]:
         # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
-    def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
+        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+
+
+class HivemindGradScaler(GradScaler):
+    def __init__(self, *args, **kwargs):
+        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
+        super().__init__(*args, **kwargs)
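
Finally, a compatibility note implied by the alias above: code that still imports HivemindGradScaler keeps working but logs a rename warning. A short sketch, assuming the module path shown in this diff:

# Old spelling: still importable, but warns that it was renamed to hivemind.GradScaler
from hivemind.optim.grad_scaler import HivemindGradScaler
scaler = HivemindGradScaler()  # behaves like GradScaler; only adds the deprecation warning

# New spelling going forward:
from hivemind.optim.grad_scaler import GradScaler
scaler = GradScaler()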