@@ -94,7 +94,7 @@ class TrainingStateAverager(DecentralizedAverager):
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
 
-        params_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
         self.reuse_tensors = reuse_tensors
@@ -104,7 +104,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
         self.optimizer, self.scheduler = self._init_components(
-            params_groups, optimizer, scheduler, initialize_optimizer
+            param_groups, optimizer, scheduler, initialize_optimizer
         )
         self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
         self.sync_epoch_when_averaging = sync_epoch_when_averaging