
Merge branch 'decentralized_lr_scheduler' of https://github.com/learning-at-home/hivemind into decentralized_lr_scheduler

xtinkt, 4 years ago
commit eedfae4bd4
2 files changed, 3 insertions and 3 deletions:

  1. hivemind/client/averaging/training.py (+1, -1)
  2. hivemind/optim/averaged.py (+2, -2)

hivemind/client/averaging/training.py (+1, -1)

@@ -132,7 +132,7 @@ class TrainingAverager(DecentralizedAverager):
                                          for param in param_group['params'])
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
-            scheduler_state = self.scheduler.state_dict() if self.scheduler else None
+            scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None
 
         metadata = dict(step=self.local_step, group_bits=self.get_group_bits(),
                         optimizer_metadata=optimizer_metadata, scheduler_state=scheduler_state)
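
The change above replaces a truthiness check on the scheduler with an explicit "is not None" comparison. For most objects the two are equivalent, but an object that defines __bool__ or __len__ can evaluate as falsy even while it exists, in which case the old code would silently discard the scheduler state. A minimal sketch of the failure mode (the FalsyScheduler class is hypothetical, written only to illustrate the difference):

    class FalsyScheduler:
        """Hypothetical scheduler-like object: falsy while it has no recorded steps."""
        def __init__(self):
            self.history = []

        def __len__(self):
            # with no __bool__ defined, Python falls back to __len__,
            # so an empty history makes the whole object falsy
            return len(self.history)

        def state_dict(self):
            return {"history": list(self.history)}

    scheduler = FalsyScheduler()

    # truthiness check: skips state_dict() because len(scheduler) == 0
    state_truthy = scheduler.state_dict() if scheduler else None
    assert state_truthy is None

    # explicit "is not None" check: saves the state as intended
    state_explicit = scheduler.state_dict() if scheduler is not None else None
    assert state_explicit == {"history": []}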

hivemind/optim/averaged.py (+2, -2)

@@ -32,13 +32,13 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param average_gradients: whether to average gradients
     :param max_allowed_epoch_difference: if max_epoch has difference with local_epoch more than that, we download state
       from other peer.
-    :param total_steps_in_epoch: how many total steps must be to increase local_epoch by one
+    :param total_steps_in_epoch: the number of optimizer steps for a single training epoch
     :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
     :param scheduler_cls: a function which takes an optimizer and returns a learning rate scheduler
     :param averaging_steps_period: performs averaging after this many optimizer steps
     :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
       many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    :param  report_progress_expiration: decentralized state time to live in dht
+    :param report_progress_expiration: decentralized state time to live in dht
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
     :param verbose: if True, outputs additional information during averaging
     :param kwargs: additional parameters passed to TrainingAverager
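
Taken together, the documented parameters describe when the optimizer averages with peers and how the learning rate scheduler is attached. A hedged usage sketch follows; the positional arguments (the wrapped torch optimizer and the DHT handle) and the hivemind.DHT(start=True) call are assumptions inferred from the surrounding API, not facts shown in this diff:

    from functools import partial

    import torch
    import hivemind
    from hivemind.optim.averaged import DecentralizedOptimizer

    model = torch.nn.Linear(16, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    dht = hivemind.DHT(start=True)  # assumed: a running DHT node for peer discovery

    optimizer = DecentralizedOptimizer(
        opt,                             # assumed positional: the wrapped torch optimizer
        dht,                             # assumed positional: DHT used to find peers
        average_gradients=False,         # documented: whether to average gradients
        max_allowed_epoch_difference=1,  # download state if peers are >1 epoch ahead
        total_steps_in_epoch=1000,       # optimizer steps that make up one epoch
        scheduler_cls=partial(torch.optim.lr_scheduler.StepLR, step_size=300),
        averaging_steps_period=10,       # average after every 10 optimizer steps...
        averaging_time_period=60.0,      # ...checked at most once per minute
        timeout=30.0,                    # cancel averaging if no group forms in 30 s
        verbose=True,
    )

The scheduler_cls value matches the documented contract ("a function which takes an optimizer and returns a learning rate scheduler"): functools.partial binds every StepLR argument except the optimizer, so the optimizer can construct the scheduler internally.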