@@ -1,18 +1,20 @@
 """ An extension of averager that supports common optimization use cases. """

 import logging
-from asyncio import Future
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
-from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union

 import torch

 import hivemind
 from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack

 logger = get_logger(__name__)
@@ -36,7 +38,7 @@ class TrainingStateAverager(DecentralizedAverager):

     Example:

-    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
     >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
     >>> avgr.load_state_from_peers()
     >>> for i, batch in enumerate(training_dataloader):
@@ -49,7 +51,7 @@ class TrainingStateAverager(DecentralizedAverager):
       TrainingStateAverager.step(..., optimizer_step=True)

     :param optimizer: PyTorch Optimizer or a callable that creates a optimizer from param groups
-    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
     :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
       :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
     :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
@@ -60,8 +62,11 @@ class TrainingStateAverager(DecentralizedAverager):
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+      Defaults to True if offload_optimizer, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
-    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
       :note: you can use extra_tensors to for any tensors not used by the optimizer (e.g. batchnorm statistics)
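For intuition, the delta rule quoted in the docstring above amounts to adding only what averaging changed, so an optimizer step taken locally while all-reduce was running is preserved instead of being overwritten. The snippet below is a standalone sketch with made-up values, not part of the patched file:

import torch

# snapshot taken right before all-reduce starts
state_before_averaging = torch.tensor([1.0, 2.0, 3.0])
# result agreed on by the averaging group
averaging_result = torch.tensor([1.5, 2.5, 3.5])
# local state that advanced by an extra optimizer step in the meantime
local_state = torch.tensor([1.1, 2.1, 3.1])

# state_tensor := state_tensor + averaging_result - state_tensor_before_averaging
delta = averaging_result - state_before_averaging
local_state += delta  # tensor([1.6, 2.6, 3.6]) - keeps the local progress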
@@ -73,12 +78,14 @@ class TrainingStateAverager(DecentralizedAverager):
         *,
         dht: hivemind.DHT,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
         custom_gradients: bool = False,
-        reuse_tensors: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
         sync_epoch_when_averaging: bool = False,
         parameter_names: Optional[Sequence[str]] = None,
         average_opt_statistics: Sequence[str] = (),
@@ -88,20 +95,22 @@ class TrainingStateAverager(DecentralizedAverager):
     ):
         average_opt_statistics = tuple(average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
-        if offload_optimizer and reuse_tensors:
-            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")

-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)

         self.status_loglevel = status_loglevel
-        self.reuse_tensors = reuse_tensors
-        self.offload_optimizer = offload_optimizer
-        self.custom_gradients = custom_gradients
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule

         self.main_parameters, self.parameter_names = main_parameters, parameter_names
-        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
         )
@@ -109,11 +118,13 @@ class TrainingStateAverager(DecentralizedAverager):
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.local_epoch = 0

-        self.step_executor = ThreadPoolExecutor(max_workers=1)
-        self.finished_optimizer_step = Event()
-        self.finished_averaging_round = Event()
-        self.pending_update = Future()
-        self.pending_update.set_result(None)
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
+        self.pending_updates = set()

         super().__init__(
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
@@ -143,10 +154,15 @@ class TrainingStateAverager(DecentralizedAverager):
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
         return param_groups, parameters, parameter_names

-    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors:
-            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+        if self.reuse_tensors and not force_copy:
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():
                 source_tensor.share_memory_()
             return source_tensor
@@ -173,19 +189,26 @@ class TrainingStateAverager(DecentralizedAverager):

         # create optimizer
         if optimizer_is_factory:
             if self.offload_optimizer:
-                for param in self._averaged_parameters:
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )

                 next_index = 0
                 param_groups_for_optimizer = []
                 for param_group in param_groups:
                     num_params = len(param_group["params"])
-                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     next_index += num_params
-                assert next_index == len(self._averaged_parameters)
+                assert next_index == len(parameters_for_optimizer)
+
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
             else:
                 param_groups_for_optimizer = param_groups
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
@@ -198,7 +221,7 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.log(
                 self.status_loglevel,
                 "Initializing optimizer manually since it has no tensors in state dict. "
-                "To override this, please provide initialize_optimizer=False",
+                "To override this, provide initialize_optimizer=False",
             )

         if initialize_optimizer:
@@ -213,7 +236,7 @@ class TrainingStateAverager(DecentralizedAverager):

         # verify optimizer and scheduler
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
-        if self.offload_optimizer or self.reuse_tensors:
+        if self.reuse_tensors:
             for param_group in optimizer.param_groups:
                 for param in param_group["params"]:
                     assert param.is_shared()
@@ -250,7 +273,7 @@ class TrainingStateAverager(DecentralizedAverager):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
             assert local_tensor.shape == averaged_tensor.shape
             if averaged_tensor.grad is not None:
-                logger.debug(self.status_loglevel, "setting gradients for averaged tensor to None")
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")

         return averaged_tensors

@@ -274,9 +297,22 @@ class TrainingStateAverager(DecentralizedAverager):
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
         return tuple(tensor_infos)

+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
     def step(
         self,
-        wait_for_delayed_update: bool = None,
+        wait_for_delayed_updates: bool = None,
         apply_delayed_updates: bool = True,
         increment_epoch: bool = False,
         optimizer_step: bool = False,
@@ -284,6 +320,8 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
@@ -291,138 +329,205 @@ class TrainingStateAverager(DecentralizedAverager):
         Perform one or several possible actions, depending on the specified keyword args.
         The actions will be performed in the same order as specified below:

-        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+          :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
-        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+          :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
-        if wait_for_delayed_update is None:
-            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
-        if optimizer_step or averaging_round or zero_grad:
-            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
+        if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details."
+                )
         output = None

-        if wait_for_delayed_update:
-            if not self.pending_update.done():
-                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                output = self.pending_update.result()
-
-            if self.pending_update.done() and self.pending_update.exception():
-                logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result()
+                except BaseException:
+                    pass  # exception will be reported below
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")

         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()

             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
+                    self._apply_optimizer_parameters_()
+                logger.debug("Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()

         if increment_epoch:
             self.local_epoch += 1

         if optimizer_step or zero_grad or averaging_round:
-            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
-
             if self.offload_optimizer and not self.custom_gradients:
                 self._load_local_grads_into_optimizer_()

-            self.pending_update = self.step_executor.submit(
+            pending_update = self.step_executor.submit(
                 self._do,
+                wait_for_trigger,
                 optimizer_step,
                 zero_grad,
                 averaging_round,
+                averaging_control,
                 grad_scaler,
                 **averaging_opts or {},
             )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging

-            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+            if should_await_optimizer:
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
-                if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Finished optimizer step")
+                if self.offload_optimizer and not should_await_averaging:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Finished optimizer step")

-            if averaging_round and not delay_averaging:
+            if should_await_averaging:
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.clear()
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished averaging round")

-            if not delay_averaging:
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
                 try:
-                    output = self.pending_update.result()
+                    output = pending_update.result()
                 finally:
-                    self.finished_averaging_round.clear()
-                    self.finished_optimizer_step.clear()
+                    self.pending_updates.remove(pending_update)
+
         return output

     def _do(
-        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+        self,
+        wait_for_trigger: Optional[Callable[[], Any]],
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        averaging_control: Optional[StepControl],
+        grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
+        **kwargs,
     ):
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         """
-        try:
-            if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                if grad_scaler is None:
-                    self.optimizer.step()
-                else:
-                    with grad_scaler.running_global_step():
-                        assert grad_scaler.step(self.optimizer)
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None

-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
+        start_time = time.perf_counter()
+        began_running = False

-            self._update_scheduler()
-
-            if zero_grad:
-                logger.log(self.status_loglevel, f"Running zero grad")
-                self.optimizer.zero_grad()
-                if self.offload_optimizer:
-                    for parameter in self.main_parameters:
-                        if parameter.grad is not None:
-                            parameter.grad.zero_()
+        try:
+            if averaging_round and averaging_control is None:
+                averaging_control = super().step(
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
+                )

-            self.finished_optimizer_step.set()
+            if wait_for_trigger is not None:
+                wait_for_trigger()
+            began_running = True
+
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
+
+                self._update_scheduler()
+                self.finished_optimizer_step.set()

             if averaging_round:
-                if not self.reuse_tensors:
-                    self._load_local_tensors_into_averager_()
-                try:
-                    gathered = super().step(gather=self.local_epoch, **kwargs)
-                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    self.finished_averaging_round.set()
-                    gathered = {}
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}

-                self.finished_averaging_round.set()
+                    self.finished_averaging_round.set()

                 if self.sync_epoch_when_averaging:
                     old_epoch = self.local_epoch
@@ -433,7 +538,12 @@ class TrainingStateAverager(DecentralizedAverager):
                         self._update_scheduler()

         except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
+            if averaging_control is not None and not averaging_control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                averaging_control.cancel()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()

@@ -447,16 +557,13 @@ class TrainingStateAverager(DecentralizedAverager):
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)

     @torch.no_grad()
-    def _apply_optimizer_results_(self):
+    def _apply_optimizer_parameters_(self):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
-        with self.lock_averaged_tensors:
-            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(
-                self.main_parameters
-            ), "Optimizer parameters changed during training"
-            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
-                main_param.copy_(offloaded_param, non_blocking=True)
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)

     @torch.no_grad()
     def _load_local_tensors_into_averager_(self):
@@ -470,18 +577,30 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_averaging_results_(self):
         """Copy averaged tensors into their respective local tensors"""
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed.")
         with self.get_tensors() as averaged_tensors:
             local_tensors = list(self._local_tensors())
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
-            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                local_tensor.copy_(averaged_tensor, non_blocking=True)
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()

     def get_current_state(self):
         """
         Get current model/optimizer state and when requested by a newbie peer. executed in the host process.
         :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
         """
-        with torch.no_grad():
+        with torch.no_grad(), self.lock_averaged_tensors:
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
             )
@@ -512,8 +631,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
-        num_parameters_and_extras = len(parameters_and_extras)
+        main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)

         loaded_state = super().load_state_from_peers(**kwargs)
         if loaded_state is None:
@@ -530,15 +649,19 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.error("Failed to load state from peer, received parameters, extras or metadata.")
             return

-        try:
-            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
-        except StopIteration:
-            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
-            return
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return

-        with torch.no_grad():
-            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
-                local_param.copy_(loaded_param, non_blocking=True)
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
+                local_param.copy_(loaded_param, non_blocking=True)
+
+        if self.offload_optimizer:
+            self._apply_optimizer_parameters_()
+
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()

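Putting the new pieces together, the intended calling pattern is to pre-schedule matchmaking with schedule_step and hand the resulting StepControl back to step via averaging_control. The sketch below follows the docstrings added in this patch; the DHT setup, the prefix string, the toy model, and the import path for TrainingStateAverager are illustrative assumptions rather than part of the patch, and with no other peers the averaging round will simply fail and be logged.

import torch
import hivemind
from hivemind.optim.state_averager import TrainingStateAverager
from hivemind.utils import get_dht_time

model = torch.nn.Linear(16, 1)
dht = hivemind.DHT(start=True)  # toy single-process DHT, just to make the sketch self-contained

avgr = TrainingStateAverager(
    dht=dht,
    optimizer=torch.optim.Adam,
    params=model.parameters(),
    prefix="demo_state_averager",  # hypothetical experiment prefix
    start=True,
)

# begin matchmaking early, aiming for an all-reduce roughly five seconds from now
control = avgr.schedule_step(scheduled_time=get_dht_time() + 5.0)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()

# run the optimizer step and the pre-scheduled averaging round in a single call
avgr.step(optimizer_step=True, zero_grad=True, averaging_round=True, averaging_control=control)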