
fix some pr issues

xtinkt 4 years ago
commit 42423cfe6b
2 changed files with 9 additions and 11 deletions
  1. + 4 - 5 hivemind/client/averaging/training.py
  2. + 5 - 6 hivemind/optim/averaged.py

+ 4 - 5
hivemind/client/averaging/training.py

@@ -10,7 +10,6 @@ from hivemind.client.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
 
 logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
 
 
 class TrainingAverager(DecentralizedAverager):
@@ -26,7 +25,7 @@ class TrainingAverager(DecentralizedAverager):
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
     :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in statedict
-    :param scheduler: if specified, averager puts scheduler state to current state
+    :param scheduler: if specified, averager keeps scheduler state
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -35,7 +34,7 @@ class TrainingAverager(DecentralizedAverager):
     """
     """
 
 
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[LRSchedulerBase] = None,
+                 average_opt_statistics: Sequence[str] = (), scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                  extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
                  extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
 
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
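The two hunks above drop the module-level `LRSchedulerBase = getattr(...)` fallback and annotate `scheduler` with `torch.optim.lr_scheduler._LRScheduler` directly. A minimal sketch of why that annotation works, assuming a torch version where the private `_LRScheduler` base class exists (it is the common ancestor of the built-in schedulers):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.SGD(params, lr=0.1)

    # Every built-in scheduler (StepLR, LambdaLR, ...) subclasses the
    # private _LRScheduler base, so it is usable as a type annotation.
    scheduler: torch.optim.lr_scheduler._LRScheduler = \
        torch.optim.lr_scheduler.StepLR(opt, step_size=10)
    assert isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler)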
@@ -161,9 +160,9 @@ class TrainingAverager(DecentralizedAverager):
             load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
 
         self.local_step = max(self.local_step, metadata['step'])
-        if self.scheduler:
+        if self.scheduler is not None:
             if 'scheduler_state' not in metadata:
-                logger.warning("Scheduler is initialized, but there is no key 'scheduler_state' found in state")
+                logger.warning("Scheduler was passed, but there is no key 'scheduler_state' found in state")
             else:
                 self.scheduler.load_state_dict(metadata['scheduler_state'])
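The switch from `if self.scheduler:` to `if self.scheduler is not None:` is not cosmetic: truthiness delegates to an object's `__bool__` or `__len__`, so a scheduler that happened to be "falsy" would be skipped silently, while the identity check only treats the `None` default as absent. A standalone sketch of the pattern (plain torch, not hivemind code):

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)

    # Identity check: only the missing default (None) counts as "no scheduler".
    if scheduler is not None:
        scheduler.load_state_dict(scheduler.state_dict())  # round-trip, as in the hunk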
 
 

+ 5 - 6
hivemind/optim/averaged.py

@@ -30,7 +30,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param average_parameters: whether to average model parameters
     :param average_gradients: whether to average gradients
-    :param max_allowed_epoch_difference: if max_epoch has difference with local_epoch mote than that, we download state
+    :param max_allowed_epoch_difference: if max_epoch has difference with local_epoch more than that, we download state
       from other peer.
     :param total_steps_in_epoch: how many total steps must be to increase local_epoch by one
     :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
@@ -38,7 +38,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param averaging_steps_period: performs averaging after this many optimizer steps
     :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
       many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    report_progress_expiration
+    :param  report_progress_expiration: decentralized state time to live in dht
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
     :param verbose: verbose info
     :param kwargs: additional parameters passed to TrainingAverager
@@ -48,8 +48,8 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
     """
 
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int = 1,
-                 total_steps_in_epoch: int = 1000, average_opt_statistics: Sequence[str] = (),
+                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
+                 max_allowed_epoch_difference: int = 1, total_steps_in_epoch: int = 1000,
                  scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
                  verbose: bool = False, **kwargs):
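Because every argument after `*` in this signature is keyword-only, reordering `average_opt_statistics` ahead of the epoch parameters cannot break existing callers. A hypothetical call site under that assumption (the single-peer DHT bootstrap and the "exp_avg_sq" statistic name are illustrative, not taken from this commit):

    import torch
    import hivemind
    from hivemind.optim.averaged import DecentralizedOptimizer

    model = torch.nn.Linear(4, 2)
    dht = hivemind.DHT(start=True)  # first peer; real runs pass initial_peers=...

    # Keyword-only arguments are order-independent, so this call is valid
    # both before and after the signature reshuffle in the hunk above.
    optimizer = DecentralizedOptimizer(
        torch.optim.Adam(model.parameters()), dht,
        prefix="demo_run", target_group_size=16,
        average_parameters=True, average_gradients=False,
        average_opt_statistics=("exp_avg_sq",),  # Adam's second-moment buffers
        max_allowed_epoch_difference=1, total_steps_in_epoch=1000,
    )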
@@ -131,7 +131,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.averager.shutdown()
 
     def load_states_from_peers(self, **kwargs):
-        logger.info("Trying to restore state from peers.")
+        logger.debug("Trying to restore state from peers.")
         with self.lock_parameters, self.lock_scheduler_params:
             self.zero_grad()
             self.averager.load_state_from_peers(**kwargs)
@@ -168,7 +168,6 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             except Exception as e:
                 logger.error(f"Averaging round failed: caught {e}.")
 
-    @property
     def is_synchronized(self) -> bool:
         return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
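Dropping `@property` turns `is_synchronized` into a regular method, which silently changes call sites: the bare attribute is now a bound method, and a bound method is always truthy. A hypothetical stand-in class to illustrate the pitfall:

    class Peer:
        # Hypothetical stand-in; mirrors only the synchronization condition.
        def __init__(self, local_epoch: int, max_epoch: int,
                     max_allowed_epoch_difference: int = 1):
            self.local_epoch = local_epoch
            self.max_epoch = max_epoch
            self.max_allowed_epoch_difference = max_allowed_epoch_difference

        def is_synchronized(self) -> bool:  # no @property, as in the hunk above
            return self.local_epoch + self.max_allowed_epoch_difference >= self.max_epoch

    peer = Peer(local_epoch=3, max_epoch=10)
    assert not peer.is_synchronized()   # correct: call the method
    assert bool(peer.is_synchronized)   # trap: a bare bound method is always truthy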