|
@@ -1,3 +1,5 @@
|
|
|
+from dataclasses import dataclass
|
|
|
+
|
|
|
import time
|
|
|
from threading import Thread, Lock, Event
|
|
|
from typing import Optional, Sequence, Tuple
|
|
@@ -7,10 +9,16 @@ import torch
|
|
|
from hivemind.dht import DHT
|
|
|
from hivemind.client import TrainingAverager
|
|
|
from hivemind.optim.base import DecentralizedOptimizerBase
|
|
|
-from hivemind.utils import get_logger, get_dht_time
|
|
|
+from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
# Alias used below to annotate the optional `scheduler` argument of DecentralizedOptimizer.
# Resolved with getattr so the name falls back to None if torch has no `_LRScheduler` attribute.
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
|
|
|
+
|
|
|
|
|
|
@dataclass
class DecentralizedState:
    """Mutable snapshot of the training progress reported across the collaboration."""

    # Highest epoch number among the local peer and every peer that reported progress.
    max_epoch: int = 0
    # Number of steps accumulated within ``max_epoch``.
    total_steps: int = 0
|
|
|
|
|
|
class DecentralizedOptimizer(DecentralizedOptimizerBase):
|
|
|
"""
|
|
@@ -35,12 +43,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
|
|
|
"""
|
|
|
|
|
|
def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
|
|
|
- average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
|
|
|
+ average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
|
|
|
+ total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
|
|
|
+ scheduler: Optional[LRSchedulerBase] = None,
|
|
|
averaging_steps_period: int = 1, averaging_time_period: float = 0,
|
|
|
- timeout: Optional[float] = None, verbose: bool = False, **kwargs):
|
|
|
+ report_progress_expiration: int = 30, timeout: Optional[float] = None,
|
|
|
+ verbose: bool = False, **kwargs):
|
|
|
super().__init__(opt, dht)
|
|
|
assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
|
|
|
self.local_step, self.averaging_step_period = 0, averaging_steps_period
|
|
|
+ self.dht = dht
|
|
|
|
|
|
self.averager = TrainingAverager(opt, average_parameters=average_parameters,
|
|
|
average_gradients=average_gradients,
|
|
@@ -49,19 +61,61 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
|
|
|
target_group_size=target_group_size, **kwargs)
|
|
|
self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
|
|
|
|
|
|
+ self.scheduler = scheduler
|
|
|
+
|
|
|
+ self.epoch = 0
|
|
|
+ self.report_progress_expiration = report_progress_expiration
|
|
|
+ self.max_allowed_epoch_difference = max_allowed_epoch_difference
|
|
|
+ self.total_steps_in_epoch = total_steps_in_epoch
|
|
|
+ self.report_progress_key = f"{prefix}.progress"
|
|
|
+ self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
|
|
|
+ self.lock_scheduler_params = Lock()
|
|
|
+ self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)
|
|
|
+ self.fetch_decentralized_state_event.set()
|
|
|
+ self._fetch_decentralized_state(initial=True)
|
|
|
+
|
|
|
self.background_averaging_thread = Thread(
|
|
|
name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
|
|
|
args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
|
|
|
kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose))
|
|
|
self.background_averaging_thread.start()
|
|
|
+ self.background_report_progress = Thread(name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
|
|
|
+ self.background_report_progress.start()
|
|
|
+ self.background_fetch_decentralized_state = Thread(
|
|
|
+ name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
|
|
|
+ self.background_fetch_decentralized_state.start()
|
|
|
+
|
|
|
+
|
|
|
|
|
|
def step(self, *args, **kwargs):
    """
    Run one local optimizer step while keeping the epoch counter and LR scheduler aligned with peers.

    :param args: positional arguments forwarded to the wrapped optimizer's ``step``
    :param kwargs: keyword arguments forwarded to the wrapped optimizer's ``step``; when this peer is
      out of sync they are forwarded to ``load_states_from_peers`` instead
    :returns: the value returned by the wrapped optimizer's ``step``, or None if the peer was out of
      sync and reloaded its state from peers instead of stepping
    """
    if not self.is_synchronized:
        logger.warning("Peer is out of sync.")
        # NOTE(review): the optimizer-step kwargs are reused for the state download - confirm intended.
        self.load_states_from_peers(**kwargs)
        return None

    with self.lock_scheduler_params:
        # Read the snapshot once: the background thread may swap in a new one at any moment.
        state = self.decentralized_state

        if self.epoch < state.max_epoch:
            # Some peer is already in a later epoch: catch up and restart local step counting.
            self.local_step = 0
            self.epoch = state.max_epoch

        # Advance only when the snapshot describes the epoch we are currently in. The snapshot is
        # refreshed asynchronously, so without this guard every step() issued before the refresh
        # would see the same stale total and bump the epoch again.
        if self.epoch == state.max_epoch and state.total_steps >= self.total_steps_in_epoch:
            self.epoch += 1
            self.local_step = 0

        if self.scheduler:
            # Bring the scheduler's own step counter up to the current epoch.
            while self.epoch > self.scheduler._step_count:
                self.scheduler.step()

    with self.lock_parameters:
        step_result = self.opt.step(*args, **kwargs)
    self.local_step += 1

    if self.local_step % self.averaging_step_period == 0:
        self.update_event.set()  # wake the background averaging thread
    # Wake the background threads that publish local progress and re-read the peers' progress.
    self.report_progress_event.set()
    self.fetch_decentralized_state_event.set()

    return step_result
|
|
|
|
|
|
def zero_grad(self, *args, **kwargs):
    """Reset gradients through the wrapped optimizer, forwarding all arguments unchanged."""
    wrapped_zero_grad = self.opt.zero_grad
    return wrapped_zero_grad(*args, **kwargs)
|
|
@@ -75,6 +129,14 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
|
|
|
self.update_event.set()
|
|
|
self.averager.shutdown()
|
|
|
|
|
|
def load_states_from_peers(self, **kwargs):
    """
    Drop local gradients and step count, then adopt the state downloaded from other peers.

    :param kwargs: forwarded unchanged to the averager's ``load_state_from_peers``
    """
    logger.info("Trying to restore state from peers.")
    # Hold both locks so that neither a parameter update nor an epoch update runs in the meantime.
    with self.lock_parameters:
        with self.lock_scheduler_params:
            self.zero_grad()
            self.averager.load_state_from_peers(**kwargs)
            # Restart step counting at the newest epoch known from the last progress snapshot.
            self.local_step, self.epoch = 0, self.decentralized_state.max_epoch
|
|
|
+
|
|
|
@staticmethod
|
|
|
@torch.no_grad()
|
|
|
def _average_parameters_in_background(
|
|
@@ -105,6 +167,58 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
|
|
|
except Exception as e:
|
|
|
logger.error(f"Averaging round failed: caught {e}.")
|
|
|
|
|
|
@property
def is_synchronized(self) -> bool:
    """Whether this peer trails the newest known epoch by at most ``max_allowed_epoch_difference``."""
    epochs_behind = self.decentralized_state.max_epoch - self.epoch
    return epochs_behind <= self.max_allowed_epoch_difference
|
|
|
+
|
|
|
@torch.no_grad()
def _report_progress(self):
    """
    Background loop that publishes this peer's progress to the DHT each time it is signalled.

    Every record is ``[local_step, dht_time, epoch]``, stored under ``report_progress_key`` with the
    averager endpoint as subkey, and expires ``report_progress_expiration`` after it was taken.
    """
    while True:
        if self.stop_event.is_set():
            return
        self.report_progress_event.wait()
        self.report_progress_event.clear()
        if self.stop_event.is_set():
            return

        timestamp = get_dht_time()
        # Read step and epoch under the lock so that the pair is consistent.
        with self.lock_scheduler_params:
            progress_record = [self.local_step, timestamp, self.epoch]
        self.dht.store(
            key=self.report_progress_key,
            subkey=self.averager.endpoint,
            value=progress_record,
            expiration_time=timestamp + self.report_progress_expiration,
            return_future=False,
        )
|
|
|
+
|
|
|
@torch.no_grad()
def _fetch_decentralized_state(self, initial: bool = False):
    """
    Read collaboration state reported by peers

    Each time ``fetch_decentralized_state_event`` is set, reads the progress records stored under
    ``report_progress_key`` and replaces ``self.decentralized_state`` with a fresh snapshot: the highest
    epoch among this peer and all reporting peers, and the sum of steps reported by peers at that epoch.
    If no records are found, the snapshot mirrors the local epoch and local step count.

    :param initial: if True, refresh the snapshot once and return instead of looping until stop_event is set
    """
    while not self.stop_event.is_set():
        self.fetch_decentralized_state_event.wait()
        self.fetch_decentralized_state_event.clear()
        if self.stop_event.is_set():
            break

        response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
        if not isinstance(response, dict) or len(response) == 0:
            logger.info(f"Found no active peers: {response}")
            with self.lock_scheduler_params:
                self.decentralized_state = DecentralizedState(max_epoch=self.epoch, total_steps=self.local_step)
        else:
            # Records are written by other peers: keep only well-formed (step, time, epoch) triples so
            # that one malformed entry cannot raise during unpacking and kill this background thread.
            peer_progress = [
                entry.value for entry in response.values()
                if isinstance(entry, ValueWithExpiration)
                and isinstance(entry.value, (list, tuple)) and len(entry.value) == 3
            ]

            with self.lock_scheduler_params:
                global_epoch = max([self.epoch] + [peer_epoch for _, _, peer_epoch in peer_progress])
                # Only peers that already reached the newest epoch count towards its progress.
                total_steps = sum(peer_step for peer_step, _, peer_epoch in peer_progress
                                  if peer_epoch == global_epoch)
                self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)

        if initial:
            break
|
|
|
+
|
|
|
|
|
|
class DecentralizedSGD(DecentralizedOptimizer):
|
|
|
"""
|