@@ -36,7 +36,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     Example:
 
-    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
     >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
     >>> avgr.load_state_from_peers()
     >>> for i, batch in enumerate(training_dataloader):
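
For review context, a minimal sketch of the constructor under the new keyword. Only the optimizer=/params= spelling comes from this diff; the import path, the prefix= and start= kwargs (which I assume are forwarded to the parent DecentralizedAverager), and the toy model are illustrative assumptions, not part of the change.

    import torch
    import hivemind
    from hivemind.optim.state_averager import TrainingStateAverager  # assumed module path

    model = torch.nn.Linear(16, 2)
    dht = hivemind.DHT(start=True)

    # new spelling: params=..., formerly param_groups=...
    avgr = TrainingStateAverager(
        dht=dht,
        optimizer=torch.optim.Adam,    # factory form: instantiated from params internally
        params=model.parameters(),
        prefix="state_averager_demo",  # assumed: forwarded to DecentralizedAverager
        start=True,                    # assumed: forwarded to DecentralizedAverager
    )
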
@@ -49,7 +49,7 @@ class TrainingStateAverager(DecentralizedAverager):
       TrainingStateAverager.step(..., optimizer_step=True)
 
     :param optimizer: PyTorch Optimizer or a callable that creates a optimizer from param groups
-    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
     :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
       :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
     :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
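
Since params (previously param_groups) accepts either a flat parameter list or torch.optim-style structured param groups, the rename also covers the per-group-options case. A hedged continuation of the first sketch; the group options are illustrative, not from this diff.

    # continuing the setup from the first sketch (model, dht, imports)
    # torch.optim-style structured groups: a list of dicts, each with its own "params"
    groups = [
        {"params": [model.weight], "lr": 1e-3},
        {"params": [model.bias], "lr": 1e-4, "weight_decay": 0.0},
    ]
    avgr = TrainingStateAverager(
        dht=dht,
        optimizer=torch.optim.Adam,
        params=groups,                  # renamed keyword, structured-groups form
        prefix="state_averager_demo2",  # assumed kwarg, as above
        start=True,
    )
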
@@ -61,7 +61,7 @@ class TrainingStateAverager(DecentralizedAverager):
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
-    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
       :note: you can use extra_tensors to for any tensors not used by the optimizer (e.g. batchnorm statistics)
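
The reworded parameter_names line above mostly matters for call sites that build names and tensors together; a sketch of keeping the two aligned, continuing the same toy setup (all kwargs other than params/parameter_names remain assumptions).

    # continuing the setup from the first sketch (model, dht, imports)
    named = list(model.named_parameters())
    avgr = TrainingStateAverager(
        dht=dht,
        optimizer=torch.optim.Adam,
        params=[param for _, param in named],
        parameter_names=[name for name, _ in named],  # same order as in params
        prefix="state_averager_demo3",                # assumed kwarg, as above
        start=True,
    )
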
@@ -73,7 +73,7 @@ class TrainingStateAverager(DecentralizedAverager):
         *,
         dht: hivemind.DHT,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
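
Because params keeps its None default in the new signature, the docstring's alternative interface (a pre-constructed optimizer, no params at all) is unchanged; a hedged sketch for completeness.

    # continuing the setup from the first sketch (model, dht, imports)
    # alternative interface: pass a ready optimizer already bound to its param groups
    prebuilt = torch.optim.Adam(model.parameters(), lr=1e-3)
    avgr = TrainingStateAverager(
        dht=dht,
        optimizer=prebuilt,
        prefix="state_averager_demo4",  # assumed kwarg, as above
        start=True,
    )
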
@@ -93,7 +93,7 @@ class TrainingStateAverager(DecentralizedAverager):
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
 
-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+        params_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
         self.reuse_tensors = reuse_tensors
@@ -103,7 +103,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
         self.optimizer, self.scheduler = self._init_components(
-            param_groups, optimizer, scheduler, initialize_optimizer
+            params_groups, optimizer, scheduler, initialize_optimizer
         )
         self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
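
Finally, to close the loop on the docstring example this diff touches: a sketch of the load/train/step side, which the rename does not change. The dataloader and loss are placeholders; step(optimizer_step=True) follows the :note: quoted in the second hunk, and any further step() flags are deliberately omitted here.

    # continuing the setup from the first sketch (model, dht, imports, avgr)
    training_dataloader = [(torch.randn(8, 16), torch.randn(8, 2)) for _ in range(10)]

    avgr.load_state_from_peers()
    for i, (inputs, targets) in enumerate(training_dataloader):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        avgr.step(optimizer_step=True)  # apply the wrapped optimizer, per the :note: above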