
added scheduler state loading and saving to averager

xtinkt, 4 years ago
parent
commit 517018b136
2 files changed, 15 insertions and 4 deletions
  1. hivemind/client/averaging/training.py (+13, -3)
  2. hivemind/optim/averaged.py (+2, -1)

hivemind/client/averaging/training.py (+13, -3)

@@ -10,6 +10,7 @@ from hivemind.client.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
 
 logger = get_logger(__name__)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
 class TrainingAverager(DecentralizedAverager):
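A note on the guarded lookup added above: `_LRScheduler` is a private PyTorch base class, so it is fetched with getattr and a None fallback rather than imported directly, and the `Optional[LRSchedulerBase]` annotation degrades gracefully if the name is missing. A minimal sketch of the same pattern (the is_scheduler helper below is illustrative, not part of hivemind):

    import torch

    # Private name: look it up defensively instead of importing it directly.
    LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)

    def is_scheduler(obj) -> bool:
        """Illustrative helper: True if obj can be treated as an LR scheduler."""
        if LRSchedulerBase is not None:
            return isinstance(obj, LRSchedulerBase)
        # Fallback when the private base class is unavailable: duck-type on the state API.
        return hasattr(obj, 'state_dict') and hasattr(obj, 'load_state_dict')

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    print(is_scheduler(torch.optim.lr_scheduler.StepLR(opt, step_size=10)))  # True
    print(is_scheduler(opt))  # False: an optimizer is not a scheduler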
@@ -25,6 +26,7 @@ class TrainingAverager(DecentralizedAverager):
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
     :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in statedict
+    :param scheduler: if specified, the averager adds the scheduler state to its current state
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -33,13 +35,14 @@ class TrainingAverager(DecentralizedAverager):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
-                 initialize_optimizer: bool = True, **kwargs):
+                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[LRSchedulerBase] = None,
+                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
+        self.scheduler = scheduler
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
 
@@ -130,8 +133,10 @@ class TrainingAverager(DecentralizedAverager):
                                          for param in param_group['params'])
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            scheduler_state = self.scheduler.state_dict() if self.scheduler else None
 
-        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
+        metadata = dict(step=self.local_step, group_bits=self.get_group_bits(),
+                        optimizer_metadata=optimizer_metadata, scheduler_state=scheduler_state)
         return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
 
     def load_state_from_peers(self, **kwargs):
@@ -156,6 +161,11 @@ class TrainingAverager(DecentralizedAverager):
             load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
 
         self.local_step = max(self.local_step, metadata['step'])
+        if self.scheduler:
+            if 'scheduler_state' not in metadata:
+                logger.warning("Scheduler is initialized, but there is no key 'scheduler_state' found in state")
+            else:
+                self.scheduler.load_state_dict(metadata['scheduler_state'])
 
 
 def initialize_optimizer_state(opt: torch.optim.Optimizer):
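Taken together, the hunks above make get_current_state() pack self.scheduler.state_dict() into the shared metadata and make load_state_from_peers() restore it with load_state_dict(), warning rather than failing when a peer's state predates this change. Below is a minimal, hivemind-free sketch of that round-trip in plain PyTorch; pack_state and restore_state are illustrative helpers, not hivemind functions.

    import warnings
    from typing import Any, Dict

    import torch

    def pack_state(opt: torch.optim.Optimizer, scheduler) -> Dict[str, Any]:
        # Mirrors get_current_state(): scheduler_state is None when no scheduler is attached.
        return dict(optimizer_state=opt.state_dict(),
                    scheduler_state=scheduler.state_dict() if scheduler else None)

    def restore_state(opt: torch.optim.Optimizer, scheduler, metadata: Dict[str, Any]) -> None:
        # Mirrors load_state_from_peers(): tolerate metadata produced without a scheduler.
        opt.load_state_dict(metadata['optimizer_state'])
        if scheduler:
            if metadata.get('scheduler_state') is None:
                warnings.warn("no 'scheduler_state' found in the loaded metadata")
            else:
                scheduler.load_state_dict(metadata['scheduler_state'])

    # Toy round-trip: advance one optimizer/scheduler pair, then copy its state to a fresh pair.
    opt_a = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched_a = torch.optim.lr_scheduler.LambdaLR(opt_a, lr_lambda=lambda epoch: 0.5 ** epoch)
    opt_a.step()
    sched_a.step()

    opt_b = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched_b = torch.optim.lr_scheduler.LambdaLR(opt_b, lr_lambda=lambda epoch: 0.5 ** epoch)
    restore_state(opt_b, sched_b, pack_state(opt_a, sched_a))
    assert sched_b.last_epoch == sched_a.last_epoch  # the scheduler position survives the round-trip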

hivemind/optim/averaged.py (+2, -1)

@@ -58,14 +58,15 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.local_step, self.averaging_step_period = 0, averaging_steps_period
         self.dht = dht
 
+        self.scheduler = None if scheduler_cls is None else scheduler_cls(opt)
         self.averager = TrainingAverager(opt, average_parameters=average_parameters,
                                          average_gradients=average_gradients,
                                          average_opt_statistics=average_opt_statistics,
+                                         scheduler=self.scheduler,
                                          dht=dht, start=True, prefix=prefix,
                                          target_group_size=target_group_size, **kwargs)
 
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-        self.scheduler = None if scheduler_cls is None else scheduler_cls(opt)
         self.local_epoch = 0
         self.report_progress_expiration = report_progress_expiration
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
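In hivemind/optim/averaged.py the only functional change is ordering: scheduler_cls(opt) is now called before the TrainingAverager is constructed, so the resulting scheduler can be passed through the new scheduler argument. A hedged sketch of that factory pattern, leaving out the hivemind-specific arguments (dht, prefix, target_group_size) and using an illustrative warmup schedule:

    from functools import partial

    import torch

    # scheduler_cls is a factory: given an optimizer, it returns a ready-to-use scheduler.
    scheduler_cls = partial(torch.optim.lr_scheduler.LambdaLR,
                            lr_lambda=lambda step: min(1.0, (step + 1) / 1000))  # linear warmup

    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

    # Same ordering as the patched __init__: build the scheduler first, then hand it to the
    # component that needs it (in hivemind, the TrainingAverager via scheduler=...).
    scheduler = None if scheduler_cls is None else scheduler_cls(opt)
    print(type(scheduler).__name__, scheduler.get_last_lr())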
         self.max_allowed_epoch_difference = max_allowed_epoch_difference