
added scheduler state loading and saving to averager

xtinkt, 4 years ago
Parent commit: 517018b136
2 changed files with 15 additions and 4 deletions
  1. hivemind/client/averaging/training.py  +13 −3
  2. hivemind/optim/averaged.py  +2 −1

+ 13 - 3
hivemind/client/averaging/training.py

@@ -10,6 +10,7 @@ from hivemind.client.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
 
 logger = get_logger(__name__)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
 class TrainingAverager(DecentralizedAverager):
@@ -25,6 +26,7 @@ class TrainingAverager(DecentralizedAverager):
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
     :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in statedict
+    :param scheduler: if specified, averager puts scheduler state to current state
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -33,13 +35,14 @@ class TrainingAverager(DecentralizedAverager):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
-                 initialize_optimizer: bool = True, **kwargs):
+                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[LRSchedulerBase] = None,
+                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
+        self.scheduler = scheduler
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
 
@@ -130,8 +133,10 @@ class TrainingAverager(DecentralizedAverager):
                                          for param in param_group['params'])
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            scheduler_state = self.scheduler.state_dict() if self.scheduler else None
 
-        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
+        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(),
+                        optimizer_metadata=optimizer_metadata, scheduler_state=scheduler_state)
         return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
 
     def load_state_from_peers(self, **kwargs):
@@ -156,6 +161,11 @@ class TrainingAverager(DecentralizedAverager):
             load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
 
         self.local_step = max(self.local_step, metadata['step'])
+        if self.scheduler:
+            if 'scheduler_state' not in metadata:
+                logger.warning("Scheduler is initialized, but there is no key 'scheduler_state' found in state")
+            else:
+                self.scheduler.load_state_dict(metadata['scheduler_state'])
 
 
 def initialize_optimizer_state(opt: torch.optim.Optimizer):
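
Taken together, the training.py changes let the scheduler state travel with the rest of the shared state: get_current_state() now packs scheduler.state_dict() into the metadata, and load_state_from_peers() restores it whenever a scheduler was passed to the averager. A minimal usage sketch follows; it assumes this commit is applied and that a hivemind DHT is already running as `dht`, and the prefix and group size are placeholder values (the remaining keyword arguments appear in the diffs on this page).

# Sketch: using the new scheduler argument of TrainingAverager.
# Assumes `dht` is an initialized, running hivemind DHT; prefix and
# target_group_size below are placeholders.
import torch
from hivemind.client.averaging.training import TrainingAverager

model = torch.nn.Linear(16, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=100)

averager = TrainingAverager(
    opt, average_parameters=True, average_gradients=True,
    scheduler=scheduler,                  # new in this commit
    dht=dht, start=True, prefix='my_run', target_group_size=4)

# get_current_state() now includes scheduler.state_dict() in the metadata,
# so a joining peer restores it along with parameters and optimizer state:
averager.load_state_from_peers()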

+ 2 - 1
hivemind/optim/averaged.py

@@ -58,14 +58,15 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.local_step, self.averaging_step_period = 0, averaging_steps_period
         self.dht = dht
 
+        self.scheduler = None if scheduler_cls is None else scheduler_cls(opt)
         self.averager = TrainingAverager(opt, average_parameters=average_parameters,
                                          average_gradients=average_gradients,
                                          average_opt_statistics=average_opt_statistics,
+                                         scheduler=self.scheduler,
                                          dht=dht, start=True, prefix=prefix,
                                          target_group_size=target_group_size, **kwargs)
 
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-        self.scheduler = None if scheduler_cls is None else scheduler_cls(opt)
         self.local_epoch = 0
         self.report_progress_expiration = report_progress_expiration
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
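
In averaged.py the scheduler is now created before the TrainingAverager so the same object can be handed down to it. Since it is instantiated as scheduler_cls(opt), any scheduler_cls must be callable with the optimizer as its only positional argument. The sketch below shows one way to satisfy that; the LambdaLR choice and its decay factor are only illustrative.

# Sketch: scheduler_cls must accept just the optimizer, because
# DecentralizedOptimizer now runs `scheduler_cls(opt)` before building the
# averager. functools.partial is one way to bind any extra arguments.
from functools import partial
import torch

scheduler_cls = partial(torch.optim.lr_scheduler.LambdaLR,
                        lr_lambda=lambda epoch: 0.95 ** epoch)

# What the reordered __init__ effectively does with it:
opt = torch.optim.SGD(torch.nn.Linear(8, 2).parameters(), lr=0.1)
scheduler = scheduler_cls(opt)        # created first ...
# ... and then passed to TrainingAverager(opt, ..., scheduler=scheduler, ...)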