@@ -30,7 +30,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param average_parameters: whether to average model parameters
     :param average_gradients: whether to average gradients
-    :param max_allowed_epoch_difference: if max_epoch has difference with local_epoch mote than that, we download state
+    :param max_allowed_epoch_difference: if max_epoch has difference with local_epoch more than that, we download state
       from other peer.
     :param total_steps_in_epoch: how many total steps must be to increase local_epoch by one
     :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
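
For orientation, a minimal sketch of what these two epoch parameters describe, assuming only the semantics stated in the docstring above; the helper names are hypothetical and this is not the class's actual implementation:

    def epoch_after_step(total_steps_done: int, total_steps_in_epoch: int) -> int:
        # local_epoch grows by one every total_steps_in_epoch optimizer steps
        return total_steps_done // total_steps_in_epoch

    def should_download_state(local_epoch: int, max_epoch: int, max_allowed_epoch_difference: int) -> bool:
        # download state from another peer once the gap to the collaboration's max_epoch
        # exceeds max_allowed_epoch_difference
        return max_epoch - local_epoch > max_allowed_epoch_difference
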
@@ -38,7 +38,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :param averaging_steps_period: performs averaging after this many optimizer steps
     :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
       many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    report_progress_expiration
+    :param report_progress_expiration: decentralized state time to live in dht
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
     :param verbose: verbose info
     :param kwargs: additional parameters passed to TrainingAverager
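
A minimal sketch of how `averaging_steps_period` and `averaging_time_period` are meant to interact according to the description above; the function and its arguments are hypothetical, not the optimizer's real code:

    def should_run_averaging(steps_since_last_round: int, seconds_since_last_round: float,
                             averaging_steps_period: int, averaging_time_period: float) -> bool:
        # time gate: only attempt averaging once the configured interval has elapsed
        if averaging_time_period > 0 and seconds_since_last_round < averaging_time_period:
            return False
        # step gate: the round only happens if enough optimizer steps ran in that interval
        return steps_since_last_round >= averaging_steps_period
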
@@ -48,8 +48,8 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """

     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int = 1,
-                 total_steps_in_epoch: int = 1000, average_opt_statistics: Sequence[str] = (),
+                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
+                 max_allowed_epoch_difference: int = 1, total_steps_in_epoch: int = 1000,
                  scheduler_cls = None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
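
A hedged construction sketch using the reordered keyword arguments; the model, the single-node DHT, the prefix value, and the import path are placeholders/assumptions, and the remaining keywords keep the defaults from the signature above:

    import torch
    from hivemind import DHT
    from hivemind.optim import DecentralizedOptimizer  # import path assumed

    model = torch.nn.Linear(16, 4)          # placeholder model
    dht = DHT(start=True)                   # placeholder single-node DHT for illustration
    opt = DecentralizedOptimizer(
        torch.optim.SGD(model.parameters(), lr=0.1),
        dht,
        prefix="demo_run",                  # placeholder experiment prefix
        target_group_size=16,
        average_parameters=True,
        average_gradients=False,
        average_opt_statistics=(),          # now listed before the epoch-related arguments
        max_allowed_epoch_difference=1,
        total_steps_in_epoch=1000,
    )
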
@@ -131,7 +131,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.averager.shutdown()

     def load_states_from_peers(self, **kwargs):
-        logger.info("Trying to restore state from peers.")
+        logger.debug("Trying to restore state from peers.")
         with self.lock_parameters, self.lock_scheduler_params:
             self.zero_grad()
             self.averager.load_state_from_peers(**kwargs)
@@ -168,7 +168,6 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         except Exception as e:
             logger.error(f"Averaging round failed: caught {e}.")

-    @property
     def is_synchronized(self) -> bool:
         return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch

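Since the `@property` decorator is removed here, `is_synchronized` becomes a regular method and call sites need parentheses; a minimal before/after sketch with a placeholder optimizer variable:

    # before this change (property access):
    #     if not opt.is_synchronized:
    #         opt.load_states_from_peers()
    # after this change (plain method call):
    if not opt.is_synchronized():
        opt.load_states_from_peers()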