@@ -12,7 +12,6 @@ from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration

 logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)


 @dataclass(frozen=False)
@@ -27,15 +26,21 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):

     :param opt: a pytorch optimizer configured to update model parameters.
     :param dht: a running hivemind DHT daemon connected to other peers
+    :param prefix: all DHT keys that point to optimization metadata will have this prefix
+    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param average_parameters: whether to average model parameters
     :param average_gradients: whether to average gradients
+    :param max_allowed_epoch_difference: if the collective max_epoch is ahead of local_epoch by more than this value,
+      the peer downloads the latest state from another peer instead of stepping
+    :param total_steps_in_epoch: total number of optimizer steps (across all peers) required to advance local_epoch by one
     :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
+    :param scheduler_cls: callable that takes the optimizer as its only argument and returns a learning rate scheduler
     :param averaging_steps_period: performs averaging after this many optimizer steps
     :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
      many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
+    :param report_progress_expiration: a peer's progress record in the DHT expires after this many seconds
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
+    :param verbose: whether to log additional debug information
     :param kwargs: additional parameters passed to TrainingAverager
     :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
      necessary fields' names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
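As a reference for the new signature, here is a minimal usage sketch (illustrative only: the toy model, the DHT construction, and the import path are assumptions, not part of this change):

    import torch
    from hivemind import DHT
    from hivemind.optim import DecentralizedOptimizer

    model = torch.nn.Linear(16, 2)   # toy model
    dht = DHT(start=True)            # a running DHT node; initial_peers omitted for brevity

    optimizer = DecentralizedOptimizer(
        torch.optim.Adam(model.parameters(), lr=1e-3), dht,
        prefix='my_experiment',                  # DHT keys for optimization metadata share this prefix
        target_group_size=4,
        average_parameters=True,
        average_gradients=False,
        max_allowed_epoch_difference=1,          # re-download state if more than 1 epoch behind the swarm
        total_steps_in_epoch=1000,               # collective steps required to advance local_epoch
        average_opt_statistics=('exp_avg_sq',),  # also average Adam's second-moment buffers
        scheduler_cls=lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.95 ** epoch),
        verbose=True)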
@@ -45,8 +50,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
                  average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
                  total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
-                 scheduler: Optional[LRSchedulerBase] = None,
-                 averaging_steps_period: int = 1, averaging_time_period: float = 0,
+                 scheduler_cls=None, averaging_steps_period: int = 1, averaging_time_period: float = 0,
                  report_progress_expiration: int = 30, timeout: Optional[float] = None,
                  verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
@@ -61,9 +65,12 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                                         target_group_size=target_group_size, **kwargs)
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()

-        self.scheduler = scheduler
+        if scheduler_cls:
+            self.scheduler = scheduler_cls(opt)
+        else:
+            self.scheduler = None

-        self.epoch = 0
+        self.local_epoch = 0
         self.report_progress_expiration = report_progress_expiration
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
         self.total_steps_in_epoch = total_steps_in_epoch
@@ -85,8 +90,6 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
         self.background_fetch_decentralized_state.start()

-
-
     def step(self, *args, **kwargs):
         if not self.is_synchronized:
             logger.warning("Peer is out of sync.")
@@ -94,16 +97,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             return

         with self.lock_scheduler_params:
-            if self.epoch < self.decentralized_state.max_epoch:
+            if self.local_epoch < self.decentralized_state.max_epoch:
                 self.local_step = 0
-                self.epoch = self.decentralized_state.max_epoch
+                self.local_epoch = self.decentralized_state.max_epoch

             if self.decentralized_state.total_steps >= self.total_steps_in_epoch:
-                self.epoch += 1
+                self.local_epoch += 1
                 self.local_step = 0

             if self.scheduler:
-                while self.epoch > self.scheduler._step_count:
+                while self.local_epoch > self.scheduler._step_count:
                     self.scheduler.step()

         with self.lock_parameters:
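The hunk above is the heart of the epoch synchronization: a peer that discovers a newer collective epoch resets its local step counter, adopts that epoch, and fast-forwards its LR scheduler until the scheduler's step count catches up. A self-contained sketch of the same control flow, using plain integers and a stub scheduler (illustrative only, not part of this change):

    class StubScheduler:
        # mimics the _step_count attribute of torch.optim.lr_scheduler._LRScheduler
        def __init__(self):
            self._step_count = 1
        def step(self):
            self._step_count += 1

    local_epoch, local_step = 3, 120
    max_epoch, total_steps = 5, 0        # collective state fetched from the DHT
    total_steps_in_epoch = 1000
    scheduler = StubScheduler()

    if local_epoch < max_epoch:          # we fell behind the swarm: jump to its epoch
        local_step = 0
        local_epoch = max_epoch

    if total_steps >= total_steps_in_epoch:   # the swarm collectively finished another epoch
        local_epoch += 1
        local_step = 0

    while local_epoch > scheduler._step_count:   # fast-forward the LR schedule to the new epoch
        scheduler.step()

    assert local_epoch == 5 and scheduler._step_count == 5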
@@ -135,7 +138,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.zero_grad()
         self.averager.load_state_from_peers(**kwargs)
         self.local_step = 0
-        self.epoch = self.decentralized_state.max_epoch
+        self.local_epoch = self.decentralized_state.max_epoch

     @staticmethod
     @torch.no_grad()
@@ -169,7 +172,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):

     @property
     def is_synchronized(self) -> bool:
-        return self.epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
+        return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch

     @torch.no_grad()
     def _report_progress(self):
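A quick numeric illustration of the renamed predicate (values are made up): with max_allowed_epoch_difference = 1, a peer at local_epoch = 4 counts as synchronized while the swarm's max_epoch is at most 5, and falls out of sync at 6:

    local_epoch, max_allowed_epoch_difference = 4, 1
    assert local_epoch + max_allowed_epoch_difference >= 5          # max_epoch = 5 -> synchronized
    assert not (local_epoch + max_allowed_epoch_difference >= 6)    # max_epoch = 6 -> out of sync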
@@ -180,7 +183,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 break
             current_time = get_dht_time()
             with self.lock_scheduler_params:
-                local_state_info = [self.local_step, current_time, self.epoch]
+                local_state_info = [self.local_step, current_time, self.local_epoch]
             self.dht.store(key=self.report_progress_key, subkey=self.averager.endpoint, value=local_state_info,
                            expiration_time=current_time + self.report_progress_expiration, return_future=False)

@@ -196,7 +199,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             if not isinstance(response, dict) or len(response) == 0:
                 logger.info(f"Found no active peers: {response}")
                 with self.lock_scheduler_params:
-                    self.decentralized_state = DecentralizedState(max_epoch=self.epoch, total_steps=self.local_step)
+                    self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
                 if initial:
                     break
                 continue
@@ -205,7 +208,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             num_peers = len(valid_peer_states)

             with self.lock_scheduler_params:
-                global_epoch = self.epoch
+                global_epoch = self.local_epoch
                 for step, time, epoch in valid_peer_states:
                     global_epoch = max(global_epoch, epoch)

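For context, each record reduced by the loop above is the [local_step, report_time, local_epoch] triple that _report_progress stores under the peer's endpoint. A small sketch of the reduction with made-up peer records (how total_steps is aggregated is not visible in this diff):

    # endpoint -> [local_step, report_time, local_epoch], as fetched from the DHT
    valid_peer_states = [
        (120, 1_700_000_000.0, 4),   # peer A
        (980, 1_700_000_003.0, 5),   # peer B
    ]

    local_epoch = 4
    global_epoch = local_epoch
    for step, time, epoch in valid_peer_states:
        global_epoch = max(global_epoch, epoch)   # same reduction as the loop above

    assert global_epoch == 5   # becomes DecentralizedState.max_epoch for this refresh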