@@ -10,6 +10,7 @@ from hivemind.client.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
 
 logger = get_logger(__name__)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
 class TrainingAverager(DecentralizedAverager):
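A note on the new module-level alias (illustration only, not part of the patch): using `getattr` with a `None` default keeps the module importable on PyTorch builds that do not expose the private `_LRScheduler` base class, while still giving a usable annotation. A minimal sketch of the same pattern, with a hypothetical helper:

from typing import Optional

import torch

# Resolves to the scheduler base class when PyTorch exposes it, otherwise to None;
# an annotation like Optional[None] still evaluates instead of raising, so the
# module imports either way.
LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)

def wants_scheduler(scheduler: Optional[LRSchedulerBase] = None) -> bool:
    """Hypothetical helper: True if a scheduler instance was supplied."""
    return scheduler is not None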
@@ -25,6 +26,7 @@ class TrainingAverager(DecentralizedAverager):
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
     :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in statedict
+    :param scheduler: if specified, the averager will also save this scheduler's state and share it with peers in load_state_from_peers
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -33,13 +35,14 @@ class TrainingAverager(DecentralizedAverager):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
-                 initialize_optimizer: bool = True, **kwargs):
+                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[LRSchedulerBase] = None,
+                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
+        self.scheduler = scheduler
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
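For context, a minimal usage sketch of the new argument (not part of the patch). The model, DHT setup, and import path are assumptions, and constructor kwargs such as prefix and target_group_size follow the usual DecentralizedAverager signature, which may differ between hivemind versions:

import torch
from hivemind import DHT
from hivemind.client.averaging.training import TrainingAverager  # import path assumed from this file

model = torch.nn.Linear(16, 4)                                    # stand-in model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

dht = DHT(start=True)  # single local peer, for illustration only
averager = TrainingAverager(opt, average_parameters=True, average_gradients=False,
                            scheduler=scheduler, dht=dht, start=True,
                            prefix='demo_run', target_group_size=2)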
@@ -130,8 +133,10 @@ class TrainingAverager(DecentralizedAverager):
                                          for param in param_group['params'])
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            scheduler_state = self.scheduler.state_dict() if self.scheduler else None
 
-        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
+        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(),
+                        optimizer_metadata=optimizer_metadata, scheduler_state=scheduler_state)
         return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
 
     def load_state_from_peers(self, **kwargs):
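The practical effect on the shared payload: metadata now always carries a 'scheduler_state' entry, which is a regular scheduler state dict when a scheduler was configured and None otherwise. A self-contained illustration with placeholder values (not the averager's actual metadata):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

scheduler_state = scheduler.state_dict() if scheduler else None
metadata = dict(step=0, group_bits='', optimizer_metadata={},  # placeholder values
                scheduler_state=scheduler_state)

# Peers built without a scheduler put None here, which is why the loading side
# (next hunk) has to treat a missing or empty scheduler_state gracefully.
assert metadata['scheduler_state'] is not None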
@@ -156,6 +161,11 @@ class TrainingAverager(DecentralizedAverager):
             load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
 
         self.local_step = max(self.local_step, metadata['step'])
+        if self.scheduler:
+            if metadata.get('scheduler_state') is None:
+                logger.warning("Scheduler is initialized, but no scheduler state was found in the state loaded from peers")
+            else:
+                self.scheduler.load_state_dict(metadata['scheduler_state'])
 
 
 def initialize_optimizer_state(opt: torch.optim.Optimizer):
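Finally, a sketch of how a catching-up peer interacts with this change (assumes the `averager` and `scheduler` from the earlier sketch). The helper below is hypothetical and simply restates the guard used in load_state_from_peers: donors that predate this patch (no 'scheduler_state' key) or that run without a scheduler (value None) leave the local schedule untouched.

def maybe_load_scheduler(scheduler, metadata: dict) -> None:
    """Hypothetical standalone version of the guard in load_state_from_peers."""
    state = metadata.get('scheduler_state')
    if scheduler is None:
        return  # this peer has no scheduler to restore into
    if state is None:
        print("donor peer provided no scheduler state; keeping the local schedule")
    else:
        scheduler.load_state_dict(state)

# In practice a lagging peer just calls:
#   averager.load_state_from_peers()
# which now restores parameters, optimizer statistics and, when available, the scheduler state.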