
Support PyTorch 2.0.0 (#559)

- Fix LRSchedulerBase
- Handle None after .zero_grad() in torch 2.0.0
- Use set_to_none=True by default in torch>=2.0
- Add set_to_none param to TrainingStateAverager.step()

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
(cherry picked from commit 98531ce9aca3f082eea44532aca1df4630600dd1)
justheuristic, 2 years ago
parent commit 6a21a7306b
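
The behavioral change this commit adapts to: on torch>=2.0, Optimizer.zero_grad() defaults to set_to_none=True, so gradient buffers become None instead of zero tensors, and any later .grad.copy_() or .grad.zero_() call must handle the None case. A minimal plain-torch sketch of that difference (illustration only, not hivemind code):

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    model(torch.randn(1, 4)).sum().backward()

    opt.zero_grad()  # on torch>=2.0 this defaults to set_to_none=True
    # None on torch>=2.0, a zero tensor on torch<2.0
    print(model.weight.grad)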

+ 1 - 2
examples/albert/run_trainer.py

@@ -18,6 +18,7 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.optim.state_averager import LRSchedulerBase
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.networking import log_visible_maddrs
 
@@ -33,8 +34,6 @@ from arguments import (
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
 
 def setup_transformers_logging(process_rank: int):
     if is_main_process(process_rank):
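
With this change the example script takes the scheduler base class from hivemind instead of reaching into torch's private lr_scheduler module. A small sketch of how a training script might use the re-exported name as a type hint (make_lr_scheduler is a hypothetical helper, not part of the example):

    import torch
    from hivemind.optim.state_averager import LRSchedulerBase

    def make_lr_scheduler(opt: torch.optim.Optimizer) -> LRSchedulerBase:
        # constant schedule; the annotation resolves correctly on torch 1.x and 2.x
        return torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)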

+ 5 - 1
hivemind/moe/server/module_backend.py

@@ -8,9 +8,13 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 
+try:
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+except AttributeError:  # torch < 2.0.0
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
+
 
 class ModuleBackend:
     """

+ 8 - 6
hivemind/optim/optimizer.py

@@ -15,6 +15,7 @@ from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFacto
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
+    ZERO_GRAD_SET_TO_NONE_DEFAULT,
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,
@@ -621,7 +622,10 @@ class Optimizer(torch.optim.Optimizer):
             with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
                 assert len(averaged_gradients) == len(optimized_parameters)
                 for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+                    if opt_param.grad is None:
+                        opt_param.grad = averaged_grad.clone()
+                    else:
+                        opt_param.grad.copy_(averaged_grad, non_blocking=True)
 
         self.grad_averager.notify_used_averaged_gradients()
 
@@ -634,7 +638,7 @@ class Optimizer(torch.optim.Optimizer):
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         self._load_averaged_gradients_into_optimizer_()
 
-    def zero_grad(self, set_to_none: bool = False):
+    def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
@@ -643,11 +647,9 @@ class Optimizer(torch.optim.Optimizer):
             )
         for param_group in self.param_groups:
             for param in param_group["params"]:
-                if param.grad is None:
-                    pass
-                elif set_to_none:
+                if set_to_none:
                     param.grad = None
-                else:
+                elif param.grad is not None:
                     param.grad.zero_()
 
     def _should_load_state_from_peers(self) -> bool:
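
The None checks in the hunks above are needed because, after zero_grad(set_to_none=True), a parameter's .grad is None and calling .copy_() on it would raise. A plain-torch sketch of the failure mode and the guard (names are illustrative):

    import torch

    param = torch.nn.Parameter(torch.zeros(3))
    averaged_grad = torch.ones(3)

    param.grad = None  # state after zero_grad(set_to_none=True)
    if param.grad is None:
        param.grad = averaged_grad.clone()                   # allocate a fresh buffer
    else:
        param.grad.copy_(averaged_grad, non_blocking=True)   # reuse the existing one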

+ 19 - 3
hivemind/optim/state_averager.py

@@ -8,6 +8,7 @@ from itertools import chain
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
+from packaging.version import Version
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
@@ -22,7 +23,12 @@ logger = get_logger(__name__)
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
 
@@ -332,6 +338,7 @@ class TrainingStateAverager(DecentralizedAverager):
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -353,6 +360,8 @@ class TrainingStateAverager(DecentralizedAverager):
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+          (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
@@ -430,6 +439,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 averaging_round,
                 averaging_control,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
             )
             self.pending_updates.add(pending_update)
@@ -472,6 +482,7 @@ class TrainingStateAverager(DecentralizedAverager):
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         **kwargs,
     ):
@@ -515,7 +526,9 @@ class TrainingStateAverager(DecentralizedAverager):
                     self.optimizer.zero_grad()
                     if self.offload_optimizer:
                         for parameter in self.main_parameters:
-                            if parameter.grad is not None:
+                            if set_to_none:
+                                parameter.grad = None
+                            elif parameter.grad is not None:
                                 parameter.grad.zero_()
 
                 self._update_scheduler()
@@ -566,7 +579,10 @@ class TrainingStateAverager(DecentralizedAverager):
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
-                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = main_param.grad.clone()
+                else:
+                    opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
     def _apply_optimizer_parameters_(self):
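
A minimal usage sketch of the new parameter, assuming a TrainingStateAverager instance named state_averager has already been constructed (construction arguments omitted); the only new piece here is the set_to_none keyword:

    # rely on the version-dependent default (grads become None on torch>=2.0, zeros otherwise)
    state_averager.step(optimizer_step=True, zero_grad=True)

    # or pin the pre-2.0 behavior explicitly, keeping zeroed gradient buffers
    state_averager.step(optimizer_step=True, zero_grad=True, set_to_none=False)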

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 PyYAML
-torch>=1.9.0,<2.0.0
+torch>=1.9.0
 numpy>=1.17
 scipy>=1.2.1
 prefetch_generator>=1.0.1

+ 10 - 4
tests/test_optimizer.py

@@ -15,7 +15,7 @@ from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFacto
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -79,8 +79,11 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
         assert torch.allclose(model2.w.grad, ref_average)
 
     # after no longer use_averaged_gradients
-    assert not torch.allclose(model1.w.grad, ref_average)
-    assert not torch.allclose(model2.w.grad, ref_average)
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:  # averager1 has reuse_grad_buffers=False
+        assert model1.w.grad is None
+    else:
+        assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)  # averager2 has reuse_grad_buffers=True
 
 
 @pytest.mark.forked
@@ -151,7 +154,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
 
-    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:
+        assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
+    else:
+        assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
     assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
     assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
     assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
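
Downstream tests and callbacks can branch on the same flag the updated tests use; a short sketch with a hypothetical helper:

    import torch
    from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT

    def grads_are_cleared(module: torch.nn.Module) -> bool:
        # after a default zero_grad(): None on torch>=2.0, zero tensors on torch<2.0
        if ZERO_GRAD_SET_TO_NONE_DEFAULT:
            return all(p.grad is None for p in module.parameters())
        return all(p.grad is None or bool(torch.all(p.grad == 0)) for p in module.parameters())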