@@ -9,10 +9,10 @@ from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
 import torch
 
 import hivemind
-from hivemind import nested_compare
 from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import CompressionInfo, TensorRole
-from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -100,7 +100,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.offload_optimizer = offload_optimizer
         self.custom_gradients = custom_gradients
 
-        self._main_parameters, self._parameter_names = main_parameters, parameter_names
+        self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
@@ -197,7 +197,7 @@ class TrainingStateAverager(DecentralizedAverager):
             initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
             logger.log(
                 self.status_loglevel,
-                "Initializing optimizer manually since it has no tensors in state dict"
+                "Initializing optimizer manually since it has no tensors in state dict. "
                 "To override this, please provide initialize_optimizer=False",
             )
 
@@ -257,12 +257,12 @@ class TrainingStateAverager(DecentralizedAverager):
     def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
         """Get CompressionInfo for each state tensor, accounting for its role and specification"""
         tensor_infos = []
-        for param, param_name in zip(self._main_parameters, self._parameter_names):
+        for param, param_name in zip(self.main_parameters, self.parameter_names):
             tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
         for stats_name in self.opt_keys_for_averaging:
             opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(opt_parameters) == len(self._parameter_names)
-            for param, param_name in zip(opt_parameters, self._parameter_names):
+            assert len(opt_parameters) == len(self.parameter_names)
+            for param, param_name in zip(opt_parameters, self.parameter_names):
                 tensor_infos.append(
                     CompressionInfo.from_tensor(
                         self.optimizer.state[param][stats_name],
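
For reference, `CompressionInfo.from_tensor` tags each tensor with a key and a role so the compression layer can pick a codec per tensor. A minimal illustrative sketch of the call pattern used in this hunk (the tensor and key below are made up, not from the diff):

    import torch
    from hivemind.compression import CompressionInfo, TensorRole

    weight = torch.randn(16, 16)
    info = CompressionInfo.from_tensor(weight, key="layer.weight", role=TensorRole.PARAMETER)
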
@@ -284,7 +284,8 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        averaging_kwargs: Optional[Dict[str, Any]] = None,
+        grad_scaler: Optional[GradScaler] = None,
+        averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
         Perform one or several possible actions, depending on the specified keyword args.
@@ -298,9 +299,10 @@ class TrainingStateAverager(DecentralizedAverager):
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
-        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
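
A hedged sketch of how a trainer might call the new signature; `state_averager`, `opt`, `scaler`, and `loss` are illustrative names, not part of this diff. Per the docstring above, gradients are unscaled before the scaler is forwarded:

    scaler.scale(loss).backward()
    scaler.unscale_(opt)  # unscale first, as the grad_scaler docstring requires
    state_averager.step(optimizer_step=True, zero_grad=True, grad_scaler=scaler)
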
@@ -312,8 +314,8 @@ class TrainingStateAverager(DecentralizedAverager):
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
-        if averaging_kwargs and not averaging_round:
-            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        if averaging_opts and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
         output = None
 
         if wait_for_delayed_update:
@@ -328,19 +330,17 @@ class TrainingStateAverager(DecentralizedAverager):
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
-                logger.log(self.status_loglevel, "Received results from background averaging round")
+                logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
                     self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
         if increment_epoch:
             self.local_epoch += 1
-            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
-            self._update_scheduler()
 
         if optimizer_step or zero_grad or averaging_round:
             assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
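
The `finished_*` flags above follow a common executor-plus-Event handoff: a background job sets an Event when done, and the next foreground step applies the results. A generic, self-contained sketch of that pattern (not hivemind code):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    finished = threading.Event()

    def background_step():
        ...  # run the optimizer step and/or averaging round
        finished.set()

    executor = ThreadPoolExecutor(max_workers=1)
    pending = executor.submit(background_step)
    # at the start of the next step():
    if finished.is_set():
        ...  # apply the background results, then reset the flag
        finished.clear()
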
@@ -353,7 +353,8 @@ class TrainingStateAverager(DecentralizedAverager):
                 optimizer_step,
                 zero_grad,
                 averaging_round,
-                **averaging_kwargs or {},
+                grad_scaler,
+                **averaging_opts or {},
             )
 
         if (optimizer_step or zero_grad) and not delay_optimizer_step:
@@ -378,7 +379,9 @@ class TrainingStateAverager(DecentralizedAverager):
                 self.finished_optimizer_step.clear()
         return output
 
-    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+    def _do(
+        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+    ):
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
@@ -386,12 +389,23 @@ class TrainingStateAverager(DecentralizedAverager):
         try:
             if optimizer_step:
                 logger.log(self.status_loglevel, f"Running optimizer step")
-                self.optimizer.step()
+                if grad_scaler is None:
+                    self.optimizer.step()
+                else:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.step(self.optimizer)
+
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
+
+            self._update_scheduler()
+
             if zero_grad:
                 logger.log(self.status_loglevel, f"Running zero grad")
                 self.optimizer.zero_grad()
                 if self.offload_optimizer:
-                    for parameter in self._main_parameters:
+                    for parameter in self.main_parameters:
                         if parameter.grad is not None:
                             parameter.grad.zero_()
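
Two behavioral notes on this hunk: the loss-scale update runs even when optimizer_step is False, and _update_scheduler() now advances here on every background step instead of on the epoch increment removed earlier. A condensed sketch of the scaler contract the asserts rely on, with `scaler` and `opt` as in the diff (assumption inferred from this diff: step/update only take real effect inside running_global_step()):

    with scaler.running_global_step():  # mark this as the global, not per-peer, step
        assert scaler.step(opt)         # True once the (already unscaled) step is applied
    with scaler.running_global_step():
        assert scaler.update()          # True once the loss scale is actually updated
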
@@ -428,7 +442,7 @@ class TrainingStateAverager(DecentralizedAverager):
         """Copy local gradients into the gradient buffers of the offloaded optimizer"""
         assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
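
For context, a minimal runnable sketch of the offloading pattern this method serves: the "main" parameters live on the training device while the optimizer keeps CPU-resident copies, and local gradients are copied into the offloaded buffers before each step. All names below are illustrative:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    main_param = torch.nn.Parameter(torch.randn(4, device=device))
    opt_param = torch.nn.Parameter(main_param.detach().to("cpu"))
    opt_param.grad = torch.zeros_like(opt_param)

    main_param.grad = torch.ones_like(main_param)  # stands in for a real backward()
    opt_param.grad.copy_(main_param.grad, non_blocking=True)  # same copy as above
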
@@ -438,8 +452,10 @@ class TrainingStateAverager(DecentralizedAverager):
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
         with self.lock_averaged_tensors:
             offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
-            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
+            assert len(offloaded_parameters) == len(
+                self.main_parameters
+            ), "Optimizer parameters changed during training"
+            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
                 main_param.copy_(offloaded_param, non_blocking=True)
 
     @torch.no_grad()
@@ -471,7 +487,7 @@ class TrainingStateAverager(DecentralizedAverager):
         )
         parameter_infos = [
             CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
-            for param, key in zip(optimized_parameters, self._parameter_names)
+            for param, key in zip(optimized_parameters, self.parameter_names)
         ]
         extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
         extra_infos = [
@@ -496,7 +512,7 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or not the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
+        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
         num_parameters_and_extras = len(parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)