فهرست منبع

Support PyTorch 2.0.0 (#559)

- Fix LRSchedulerBase
- Handle None after .zero_grad() in torch 2.0.0
- Use set_to_none=True by default in torch>=2.0
- Add set_to_none param to TrainingStateAverager.step()

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic 2 سال پیش
والد
کامیت
98531ce9ac
6 فایل تغییر یافته به همراه 44 افزوده شده و 17 حذف شده
  1. 1 2
      examples/albert/run_trainer.py
  2. 5 1
      hivemind/moe/server/module_backend.py
  3. 8 6
      hivemind/optim/optimizer.py
  4. 19 3
      hivemind/optim/state_averager.py
  5. 1 1
      requirements.txt
  6. 10 4
      tests/test_optimizer.py

+ 1 - 2
examples/albert/run_trainer.py

@@ -18,6 +18,7 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 from transformers.trainer_utils import is_main_process
 
 
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.optim.state_averager import LRSchedulerBase
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.networking import log_visible_maddrs
 from hivemind.utils.networking import log_visible_maddrs
 
 
@@ -33,8 +34,6 @@ from arguments import (
 use_hivemind_log_handler("in_root_logger")
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
 
 
 def setup_transformers_logging(process_rank: int):
 def setup_transformers_logging(process_rank: int):
     if is_main_process(process_rank):
     if is_main_process(process_rank):

+ 5 - 1
hivemind/moe/server/module_backend.py

@@ -8,9 +8,13 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
+try:
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+except AttributeError:  # torch < 2.0.0
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
+
 
 
 class ModuleBackend:
 class ModuleBackend:
     """
     """

+ 8 - 6
hivemind/optim/optimizer.py

@@ -15,6 +15,7 @@ from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFacto
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
 from hivemind.optim.state_averager import (
+    ZERO_GRAD_SET_TO_NONE_DEFAULT,
     LRSchedulerBase,
     LRSchedulerBase,
     OptimizerFactory,
     OptimizerFactory,
     Parameters,
     Parameters,
@@ -621,7 +622,10 @@ class Optimizer(torch.optim.Optimizer):
             with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
             with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
                 assert len(averaged_gradients) == len(optimized_parameters)
                 assert len(averaged_gradients) == len(optimized_parameters)
                 for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
                 for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+                    if opt_param.grad is None:
+                        opt_param.grad = averaged_grad.clone()
+                    else:
+                        opt_param.grad.copy_(averaged_grad, non_blocking=True)
 
 
         self.grad_averager.notify_used_averaged_gradients()
         self.grad_averager.notify_used_averaged_gradients()
 
 
@@ -634,7 +638,7 @@ class Optimizer(torch.optim.Optimizer):
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
         self._load_averaged_gradients_into_optimizer_()
         self._load_averaged_gradients_into_optimizer_()
 
 
-    def zero_grad(self, set_to_none: bool = False):
+    def zero_grad(self, set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
             raise ValueError(
@@ -643,11 +647,9 @@ class Optimizer(torch.optim.Optimizer):
             )
             )
         for param_group in self.param_groups:
         for param_group in self.param_groups:
             for param in param_group["params"]:
             for param in param_group["params"]:
-                if param.grad is None:
-                    pass
-                elif set_to_none:
+                if set_to_none:
                     param.grad = None
                     param.grad = None
-                else:
+                elif param.grad is not None:
                     param.grad.zero_()
                     param.grad.zero_()
 
 
     def _should_load_state_from_peers(self) -> bool:
     def _should_load_state_from_peers(self) -> bool:

+ 19 - 3
hivemind/optim/state_averager.py

@@ -8,6 +8,7 @@ from itertools import chain
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 
 import torch
 import torch
+from packaging.version import Version
 
 
 import hivemind
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging import DecentralizedAverager
@@ -22,7 +23,12 @@ logger = get_logger(__name__)
 Parameters = Iterable[torch.Tensor]
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
 TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
 
 
@@ -332,6 +338,7 @@ class TrainingStateAverager(DecentralizedAverager):
         averaging_control: Optional[StepControl] = None,
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
     ):
         """
         """
@@ -353,6 +360,8 @@ class TrainingStateAverager(DecentralizedAverager):
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+          (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         """
         if delay_averaging is None:
         if delay_averaging is None:
@@ -430,6 +439,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 averaging_round,
                 averaging_round,
                 averaging_control,
                 averaging_control,
                 grad_scaler,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
                 **averaging_opts or {},
             )
             )
             self.pending_updates.add(pending_update)
             self.pending_updates.add(pending_update)
@@ -472,6 +482,7 @@ class TrainingStateAverager(DecentralizedAverager):
         averaging_round: bool,
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         timeout: Optional[float] = None,
         **kwargs,
         **kwargs,
     ):
     ):
@@ -515,7 +526,9 @@ class TrainingStateAverager(DecentralizedAverager):
                     self.optimizer.zero_grad()
                     self.optimizer.zero_grad()
                     if self.offload_optimizer:
                     if self.offload_optimizer:
                         for parameter in self.main_parameters:
                         for parameter in self.main_parameters:
-                            if parameter.grad is not None:
+                            if set_to_none:
+                                parameter.grad = None
+                            elif parameter.grad is not None:
                                 parameter.grad.zero_()
                                 parameter.grad.zero_()
 
 
                 self._update_scheduler()
                 self._update_scheduler()
@@ -566,7 +579,10 @@ class TrainingStateAverager(DecentralizedAverager):
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
             if main_param.grad is not None:
-                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = main_param.grad.clone()
+                else:
+                    opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
 
     @torch.no_grad()
     @torch.no_grad()
     def _apply_optimizer_parameters_(self):
     def _apply_optimizer_parameters_(self):

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 PyYAML
 PyYAML
-torch>=1.9.0,<2.0.0
+torch>=1.9.0
 numpy>=1.17
 numpy>=1.17
 scipy>=1.2.1
 scipy>=1.2.1
 prefetch_generator>=1.0.1
 prefetch_generator>=1.0.1

+ 10 - 4
tests/test_optimizer.py

@@ -15,7 +15,7 @@ from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFacto
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.progress_tracker import ProgressTracker
-from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.crypto import RSAPrivateKey
 
 
 
 
@@ -79,8 +79,11 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
         assert torch.allclose(model2.w.grad, ref_average)
         assert torch.allclose(model2.w.grad, ref_average)
 
 
     # after no longer use_averaged_gradients
     # after no longer use_averaged_gradients
-    assert not torch.allclose(model1.w.grad, ref_average)
-    assert not torch.allclose(model2.w.grad, ref_average)
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:  # averager1 has reuse_grad_buffers=False
+        assert model1.w.grad is None
+    else:
+        assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)  # averager2 has reuse_grad_buffers=True
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
@@ -151,7 +154,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         F.mse_loss(model2(x), -torch.ones(3)).backward()
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
         avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
 
 
-    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    if ZERO_GRAD_SET_TO_NONE_DEFAULT:
+        assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
+    else:
+        assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"
     assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
     assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
     assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
     assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
     assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
     assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"