
Prepare GradScaler for hivemind.Optimizer (#413)

- Modified hivemind.GradScaler to make it compatible with hivemind.Optimizer (backwards-compatible)
- Changed TrainingStateAverager to be compatible with hivemind.GradScaler
- Made TrainingStateAverager.main_parameters and parameter_names public for use in optimizer

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 3 years ago
Commit 09e34f8366
3 files changed, 56 insertions(+), 36 deletions(-)
  1. hivemind/optim/collaborative.py (+2 −2)
  2. hivemind/optim/experimental/state_averager.py (+40 −24)
  3. hivemind/optim/grad_scaler.py (+14 −10)
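
For context, the sketch below shows roughly the torch.cuda.amp training-step pattern that this GradScaler is meant to slot into, following its class docstring further down ("bypass .unscale_ and .update calls in order to accumulate gradients over several steps"). This is a hedged illustration, not code from this commit: `model`, `batch`, `collaborative_opt`, and `scaler` are placeholder objects, and the exact integration with CollaborativeOptimizer or the upcoming hivemind.Optimizer may differ.

import torch

def local_training_step(model, batch, collaborative_opt, scaler):
    # Placeholders: `model` is any nn.Module, `batch` its input, `collaborative_opt`
    # a hivemind optimizer wrapper around a regular torch.optim.Optimizer (the `.opt`
    # attribute seen in the diffs below), and `scaler` a hivemind GradScaler instance.
    with torch.cuda.amp.autocast():
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.unscale_(collaborative_opt)  # between global steps this only records the optimizer
    scaler.step(collaborative_opt)      # the real inner-optimizer step is deferred to the global step
    scaler.update()                     # likewise bypassed until a global (collaborative) step runs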

hivemind/optim/collaborative.py (+2 −2)

@@ -245,7 +245,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
@@ -310,7 +310,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self)
+                    assert grad_scaler.step(self.opt)
             else:
                 self.opt.step()
 

hivemind/optim/experimental/state_averager.py (+40 −24)

@@ -9,10 +9,10 @@ from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
 import torch
 
 import hivemind
-from hivemind import nested_compare
 from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import CompressionInfo, TensorRole
-from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 
@@ -100,7 +100,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.offload_optimizer = offload_optimizer
         self.custom_gradients = custom_gradients
 
-        self._main_parameters, self._parameter_names = main_parameters, parameter_names
+        self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
@@ -197,7 +197,7 @@ class TrainingStateAverager(DecentralizedAverager):
             initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
             logger.log(
                 self.status_loglevel,
-                "Initializing optimizer manually since it has no tensors in state dict"
+                "Initializing optimizer manually since it has no tensors in state dict. "
                 "To override this, please provide initialize_optimizer=False",
             )
 
@@ -257,12 +257,12 @@ class TrainingStateAverager(DecentralizedAverager):
     def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
         """Get CompressionInfo for each state tensor, accounting for its role and specification"""
         tensor_infos = []
-        for param, param_name in zip(self._main_parameters, self._parameter_names):
+        for param, param_name in zip(self.main_parameters, self.parameter_names):
             tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
         for stats_name in self.opt_keys_for_averaging:
             opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(opt_parameters) == len(self._parameter_names)
-            for param, param_name in zip(opt_parameters, self._parameter_names):
+            assert len(opt_parameters) == len(self.parameter_names)
+            for param, param_name in zip(opt_parameters, self.parameter_names):
                 tensor_infos.append(
                     CompressionInfo.from_tensor(
                         self.optimizer.state[param][stats_name],
@@ -284,7 +284,8 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        averaging_kwargs: Optional[Dict[str, Any]] = None,
+        grad_scaler: Optional[GradScaler] = None,
+        averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
         Perform one or several possible actions, depending on the specified keyword args.
@@ -298,9 +299,10 @@ class TrainingStateAverager(DecentralizedAverager):
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
-        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
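
A short sketch of the calling convention that the new grad_scaler docstring entry describes: unscale_ against the hivemind optimizer wrapper first, then forward the scaler into TrainingStateAverager.step. Apart from the step keyword arguments shown in this diff, every name below (hivemind_optimizer, state_averager, the flag values) is a placeholder for illustration, not code from this commit:

# Hypothetical caller, e.g. the hivemind.Optimizer this PR prepares for:
grad_scaler.unscale_(hivemind_optimizer)  # placeholder DecentralizedOptimizerBase-style wrapper
state_averager.step(
    optimizer_step=True,
    zero_grad=True,
    averaging_round=False,                # flags chosen for illustration only
    grad_scaler=grad_scaler,
)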
@@ -312,8 +314,8 @@ class TrainingStateAverager(DecentralizedAverager):
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
-        if averaging_kwargs and not averaging_round:
-            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        if averaging_opts and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
         output = None
 
         if wait_for_delayed_update:
@@ -328,19 +330,17 @@ class TrainingStateAverager(DecentralizedAverager):
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
-                logger.log(self.status_loglevel, "Received results from background averaging round")
+                logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
                     self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
         if increment_epoch:
             self.local_epoch += 1
-            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
-            self._update_scheduler()
 
         if optimizer_step or zero_grad or averaging_round:
             assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
@@ -353,7 +353,8 @@ class TrainingStateAverager(DecentralizedAverager):
                 optimizer_step,
                 zero_grad,
                 averaging_round,
-                **averaging_kwargs or {},
+                grad_scaler,
+                **averaging_opts or {},
             )
 
             if (optimizer_step or zero_grad) and not delay_optimizer_step:
@@ -378,7 +379,9 @@ class TrainingStateAverager(DecentralizedAverager):
                     self.finished_optimizer_step.clear()
         return output
 
-    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+    def _do(
+        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+    ):
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
@@ -386,12 +389,23 @@ class TrainingStateAverager(DecentralizedAverager):
         try:
             if optimizer_step:
                 logger.log(self.status_loglevel, f"Running optimizer step")
-                self.optimizer.step()
+                if grad_scaler is None:
+                    self.optimizer.step()
+                else:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.step(self.optimizer)
+
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
+
+            self._update_scheduler()
+
             if zero_grad:
                 logger.log(self.status_loglevel, f"Running zero grad")
                 self.optimizer.zero_grad()
                 if self.offload_optimizer:
-                    for parameter in self._main_parameters:
+                    for parameter in self.main_parameters:
                         if parameter.grad is not None:
                             parameter.grad.zero_()
 
@@ -428,7 +442,7 @@ class TrainingStateAverager(DecentralizedAverager):
         """Copy local gradients into the gradient buffers of the offloaded optimizer"""
         """Copy local gradients into the gradient buffers of the offloaded optimizer"""
         assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
         assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
             if main_param.grad is not None:
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
 
@@ -438,8 +452,10 @@ class TrainingStateAverager(DecentralizedAverager):
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
         with self.lock_averaged_tensors:
             offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
-            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
+            assert len(offloaded_parameters) == len(
+                self.main_parameters
+            ), "Optimizer parameters changed during training"
+            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
                 main_param.copy_(offloaded_param, non_blocking=True)
 
     @torch.no_grad()
@@ -471,7 +487,7 @@ class TrainingStateAverager(DecentralizedAverager):
             )
             parameter_infos = [
                 CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
-                for param, key in zip(optimized_parameters, self._parameter_names)
+                for param, key in zip(optimized_parameters, self.parameter_names)
             ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             extra_infos = [
@@ -496,7 +512,7 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or not the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
+        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
         num_parameters_and_extras = len(parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)

hivemind/optim/grad_scaler.py (+14 −10)

@@ -4,7 +4,7 @@ from typing import Dict, Optional
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
 from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
-from torch.optim import Optimizer
+from torch.optim import Optimizer as TorchOptimizer
 
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils.logging import get_logger
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-class HivemindGradScaler(TorchGradScaler):
+class GradScaler(TorchGradScaler):
     """
     A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
     - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
@@ -33,7 +33,7 @@ class HivemindGradScaler(TorchGradScaler):
         finally:
             self._is_running_global_step = was_running
 
-    def unscale_(self, optimizer: Optimizer) -> bool:
+    def unscale_(self, optimizer: TorchOptimizer) -> bool:
         assert isinstance(optimizer, DecentralizedOptimizerBase)
         if self._is_running_global_step:
             super().unscale_(optimizer.opt)
@@ -43,11 +43,10 @@ class HivemindGradScaler(TorchGradScaler):
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
-    def step(self, optimizer: Optimizer, *args, **kwargs) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
+    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             if self.are_grads_finite(optimizer):
-                super().step(optimizer.opt, *args, **kwargs)
+                super().step(optimizer, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
             return True
@@ -72,12 +71,17 @@ class HivemindGradScaler(TorchGradScaler):
             return False
 
     def _unscale_grads_(
-        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
     ) -> Dict[torch.device, torch.Tensor]:
         # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
-    def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
+        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+
+
+class HivemindGradScaler(GradScaler):
+    def __init__(self, *args, **kwargs):
+        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
+        super().__init__(*args, **kwargs)
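
For backwards compatibility, the old name remains available as a thin subclass that only logs a warning on construction. A hedged usage note, assuming both names stay exported from the package root as the warning message implies:

from hivemind import GradScaler            # new name introduced in this commit
from hivemind import HivemindGradScaler    # still importable, but deprecated

scaler = HivemindGradScaler()  # warns: renamed to hivemind.GradScaler, reference removed in v1.1
scaler = GradScaler()          # preferred going forward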