@@ -8,6 +8,7 @@ from itertools import chain
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
+from packaging.version import Version
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
@@ -22,7 +23,12 @@ logger = get_logger(__name__)
 
 Parameters = Iterable[torch.Tensor]
 ParamGroups = Iterable[Dict[str, Any]]
 TorchOptimizer = torch.optim.Optimizer
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+if Version(torch.__version__).major >= 2:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = True
+    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler
+else:
+    ZERO_GRAD_SET_TO_NONE_DEFAULT = False
+    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler
 OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
 SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
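For context (not part of the patch): packaging.version.Version parses PEP 440 version strings, including local segments such as "+cu118" that torch.__version__ often carries, so the major-version check above resolves correctly. A minimal sanity check:

import torch
from packaging.version import Version

# Version() understands PEP 440, including local segments like "+cu118":
assert Version("2.1.0+cu118").major == 2
assert Version("1.13.1").major == 1
print(Version(torch.__version__).major)  # 2 on PyTorch 2.x installs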
@@ -332,6 +338,7 @@ class TrainingStateAverager(DecentralizedAverager):
         averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
+        set_to_none: bool = ZERO_GRAD_SET_TO_NONE_DEFAULT,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -353,6 +360,8 @@ class TrainingStateAverager(DecentralizedAverager):
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param set_to_none: if True, zero_grad sets local gradients to None instead of zero tensors
+          (default in PyTorch 2.0+)
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
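An illustration of the PyTorch behavior the new flag mirrors (not part of the patch; plain torch.optim.SGD stands in for the wrapped optimizer):

import torch

param = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([param], lr=0.1)

param.grad = torch.full((3,), 2.0)
opt.zero_grad(set_to_none=True)   # gradient buffers are dropped entirely
assert param.grad is None

param.grad = torch.full((3,), 2.0)
opt.zero_grad(set_to_none=False)  # gradients are kept as zero-filled tensors
assert torch.equal(param.grad, torch.zeros(3))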
@@ -430,6 +439,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 averaging_round,
                 averaging_control,
                 grad_scaler,
+                set_to_none,
                 **averaging_opts or {},
             )
             self.pending_updates.add(pending_update)
@@ -472,6 +482,7 @@ class TrainingStateAverager(DecentralizedAverager):
         averaging_round: bool,
         averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
+        set_to_none: bool,
         timeout: Optional[float] = None,
         **kwargs,
     ):
@@ -515,7 +526,9 @@ class TrainingStateAverager(DecentralizedAverager):
                         self.optimizer.zero_grad()
                         if self.offload_optimizer:
                             for parameter in self.main_parameters:
-                                if parameter.grad is not None:
+                                if set_to_none:
+                                    parameter.grad = None
+                                elif parameter.grad is not None:
                                     parameter.grad.zero_()
 
                     self._update_scheduler()
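A minimal sketch of the branch above as a standalone helper (zero_offloaded_grads is a hypothetical name for illustration, not a hivemind function):

import torch

def zero_offloaded_grads(parameters, set_to_none: bool):
    # Mirrors the patched branch: either drop the gradient buffers
    # entirely or zero them in place.
    for parameter in parameters:
        if set_to_none:
            parameter.grad = None
        elif parameter.grad is not None:
            parameter.grad.zero_()

p = torch.nn.Parameter(torch.ones(2))
p.grad = torch.ones(2)
zero_offloaded_grads([p], set_to_none=False)
assert torch.equal(p.grad, torch.zeros(2))
zero_offloaded_grads([p], set_to_none=True)
assert p.grad is None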
@@ -566,7 +579,10 @@ class TrainingStateAverager(DecentralizedAverager):
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
         for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
-                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+                if opt_param.grad is None:
+                    opt_param.grad = main_param.grad.clone()
+                else:
+                    opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
     def _apply_optimizer_parameters_(self):
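Why the clone() fallback is needed (a sketch, not part of the patch): after zeroing with set_to_none=True, the offloaded optimizer's gradient buffer may be None, and calling .copy_() on a missing gradient would raise AttributeError, so the buffer is recreated by cloning:

import torch

main_param = torch.nn.Parameter(torch.ones(3))
main_param.grad = torch.full((3,), 0.5)

opt_param = torch.nn.Parameter(torch.ones(3))
opt_param.grad = None  # state after zero_grad with set_to_none=True

# Recreate the buffer by cloning when it is missing; otherwise reuse it.
if opt_param.grad is None:
    opt_param.grad = main_param.grad.clone()
else:
    opt_param.grad.copy_(main_param.grad, non_blocking=True)

assert torch.equal(opt_param.grad, main_param.grad)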
|