
fix some pr issues

xtinkt 4 years ago
parent
commit
42423cfe6b

+ 4 - 5
hivemind/client/averaging/training.py

@@ -10,7 +10,6 @@ from hivemind.client.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
 
 logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
 class TrainingAverager(DecentralizedAverager):
@@ -26,7 +25,7 @@ class TrainingAverager(DecentralizedAverager):
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
     :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in statedict
-    :param scheduler: if specified, averager puts scheduler state to current state
+    :param scheduler: if specified, averager keeps scheduler state
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -35,7 +34,7 @@ class TrainingAverager(DecentralizedAverager):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[LRSchedulerBase] = None,
+                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                  extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
@@ -161,9 +160,9 @@ class TrainingAverager(DecentralizedAverager):
             load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
 
         self.local_step = max(self.local_step, metadata['step'])
-        if self.scheduler:
+        if self.scheduler is not None:
             if 'scheduler_state' not in metadata:
-                logger.warning("Scheduler is initialized, but there is no key 'scheduler_state' found in state")
+                logger.warning("Scheduler was passed, but there is no key 'scheduler_state' found in state")
             else:
                 self.scheduler.load_state_dict(metadata['scheduler_state'])
 

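Aside: the change from `if self.scheduler:` to `if self.scheduler is not None:` in load_state_from_peers makes the check explicit. Below is a minimal, self-contained sketch (a toy class, not hivemind code) of why the explicit comparison is the safer idiom for optional objects.

# Minimal sketch (toy class, assumed example): a bare truthiness check can
# silently skip objects that happen to evaluate as falsy, while an explicit
# "is not None" only skips the case where nothing was passed at all.

class EmptyHistoryScheduler:
    """Toy stand-in for a scheduler-like object that wraps an empty container."""

    def __init__(self):
        self.history = []

    def __bool__(self) -> bool:
        # Falsy whenever the wrapped history is empty.
        return bool(self.history)


maybe_scheduler = EmptyHistoryScheduler()

if maybe_scheduler:
    print("truthiness check: branch taken")      # not printed: the object is falsy

if maybe_scheduler is not None:
    print("explicit None check: branch taken")   # printed: an object was passed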
+ 5 - 6
hivemind/optim/averaged.py

@@ -30,7 +30,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param average_parameters: whether to average model parameters
     :param average_gradients: whether to average gradients
-    :param max_allowed_epoch_difference: if max_epoch has difference with local_epoch mote than that, we download state
+    :param max_allowed_epoch_difference: if max_epoch differs from local_epoch by more than this value, we download state
       from other peer.
     :param total_steps_in_epoch: how many total steps it takes to increase local_epoch by one
     :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
@@ -38,7 +38,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param averaging_steps_period: performs averaging after this many optimizer steps
     :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
       many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    report_progress_expiration
+    :param report_progress_expiration: decentralized state time to live in the DHT
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
     :param verbose: whether to log additional information
     :param kwargs: additional parameters passed to TrainingAverager
@@ -48,8 +48,8 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int = 1,
-                 total_steps_in_epoch: int = 1000, average_opt_statistics: Sequence[str] = (),
+                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
+                 max_allowed_epoch_difference: int = 1, total_steps_in_epoch: int = 1000,
                  scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
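The constructor hunk above only reorders parameters that sit after the bare `*`, so they are keyword-only and existing call sites are unaffected. A hedged sketch of why such a reorder is backward compatible (a toy function, not the real DecentralizedOptimizer signature):

# Toy function (assumed example): everything after the bare "*" can only be
# passed by keyword, so the declaration order of these parameters does not
# matter to callers.
def configure(*, average_parameters: bool, average_gradients: bool,
              average_opt_statistics=(), max_allowed_epoch_difference: int = 1,
              total_steps_in_epoch: int = 1000):
    return dict(average_parameters=average_parameters,
                average_gradients=average_gradients,
                average_opt_statistics=tuple(average_opt_statistics),
                max_allowed_epoch_difference=max_allowed_epoch_difference,
                total_steps_in_epoch=total_steps_in_epoch)

# A call written against the old parameter order still works unchanged:
print(configure(average_parameters=True, average_gradients=False,
                max_allowed_epoch_difference=2, total_steps_in_epoch=500))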
@@ -131,7 +131,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.averager.shutdown()
 
     def load_states_from_peers(self, **kwargs):
-        logger.info("Trying to restore state from peers.")
+        logger.debug("Trying to restore state from peers.")
         with self.lock_parameters, self.lock_scheduler_params:
             self.zero_grad()
             self.averager.load_state_from_peers(**kwargs)
@@ -168,7 +168,6 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             except Exception as e:
                 logger.error(f"Averaging round failed: caught {e}.")
 
-    @property
     def is_synchronized(self) -> bool:
         return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
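With the `@property` decorator removed, `is_synchronized` becomes a regular method, so callers switch from attribute access to an explicit call. A small sketch of the resulting call-site shape (a toy class with max_epoch kept on the instance, not the hivemind implementation):

# Toy class (assumed example): same comparison as the hunk above, with
# max_epoch stored locally instead of on a decentralized state object.
class ToyOptimizer:
    def __init__(self, local_epoch: int, max_epoch: int,
                 max_allowed_epoch_difference: int = 1):
        self.local_epoch = local_epoch
        self.max_epoch = max_epoch
        self.max_allowed_epoch_difference = max_allowed_epoch_difference

    def is_synchronized(self) -> bool:
        return self.local_epoch + self.max_allowed_epoch_difference >= self.max_epoch


opt = ToyOptimizer(local_epoch=9, max_epoch=10)
print(opt.is_synchronized())   # True: within one epoch of the global maximum
# Call sites change from "if not opt.is_synchronized:" to "if not opt.is_synchronized():".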