Browse Source

add decentralized learning rate scheduler and epochs abstraction

xtinkt 4 years ago
parent
commit
de01a9cb15
1 changed file with 119 additions and 5 deletions
  1. 119 5
      hivemind/optim/decentralized_optimizers.py

+ 119 - 5
hivemind/optim/simple.py → hivemind/optim/decentralized_optimizers.py

@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+
 import time
 from threading import Thread, Lock, Event
 from typing import Optional, Sequence, Tuple
@@ -7,10 +9,16 @@ import torch
 from hivemind.dht import DHT
 from hivemind.client import TrainingAverager
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.utils import get_logger, get_dht_time
+from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
 
 logger = get_logger(__name__)
+# Scheduler base class used in type annotations below; getattr() falls back to None
+# when torch.optim.lr_scheduler does not expose a `_LRScheduler` attribute.
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+
 
+@dataclass
+class DecentralizedState:
+    """Mutable snapshot of the training progress aggregated over the local peer and reported peers."""
+    # highest epoch seen among the local peer and the peers that reported progress
+    max_epoch: int = 0
+    # number of steps accumulated for max_epoch
+    total_steps: int = 0
 
 class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
@@ -35,12 +43,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
+                 average_parameters: bool, average_gradients: bool, max_allowed_epoch_difference: int,
+                 total_steps_in_epoch: int, average_opt_statistics: Sequence[str] = (),
+                 scheduler: Optional[LRSchedulerBase] = None,
                  averaging_steps_period: int = 1, averaging_time_period: float = 0,
-                 timeout: Optional[float] = None, verbose: bool = False, **kwargs):
+                 report_progress_expiration: int = 30, timeout: Optional[float] = None,
+                 verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
         assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
         self.local_step, self.averaging_step_period = 0, averaging_steps_period
+        self.dht = dht
 
         self.averager = TrainingAverager(opt, average_parameters=average_parameters,
                                          average_gradients=average_gradients,
@@ -49,19 +61,61 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                                          target_group_size=target_group_size, **kwargs)
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
 
+        self.scheduler = scheduler
+
+        self.epoch = 0
+        self.report_progress_expiration = report_progress_expiration
+        self.max_allowed_epoch_difference = max_allowed_epoch_difference
+        self.total_steps_in_epoch = total_steps_in_epoch
+        self.report_progress_key = f"{prefix}.progress"
+        self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
+        self.lock_scheduler_params = Lock()
+        self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)
+        self.fetch_decentralized_state_event.set()
+        self._fetch_decentralized_state(initial=True)
+
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
             args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
             kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose))
         self.background_averaging_thread.start()
+        self.background_report_progress = Thread(name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
+        self.background_report_progress.start()
+        self.background_fetch_decentralized_state = Thread(
+            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state)
+        self.background_fetch_decentralized_state.start()
+
+
 
     def step(self, *args, **kwargs):
+        """
+        Run one step of the wrapped optimizer while keeping the local epoch aligned with the collaboration.
+
+        If this peer lags behind by more than max_allowed_epoch_difference epochs, no optimizer step is made:
+        the state is reloaded from peers instead and None is returned.
+
+        :param args: positional arguments passed to the wrapped optimizer's step()
+        :param kwargs: keyword arguments passed to the wrapped optimizer's step()
+        :returns: whatever the wrapped optimizer's step() returns, or None if the peer had to resynchronize
+        """
+        if not self.is_synchronized:
+            logger.warning("Peer is out of sync.")
+            # NOTE(review): kwargs of step() are forwarded to load_states_from_peers - confirm this is intended
+            self.load_states_from_peers(**kwargs)
+            return
+
+        with self.lock_scheduler_params:
+            # catch up with peers that are already in a later epoch
+            if self.epoch < self.decentralized_state.max_epoch:
+                self.local_step = 0
+                self.epoch = self.decentralized_state.max_epoch
+
+            # enough steps were accumulated for the current epoch: begin the next one
+            if self.decentralized_state.total_steps >= self.total_steps_in_epoch:
+                self.epoch += 1
+                self.local_step = 0
+                # The snapshot has been consumed. Reset it so that another step() issued before the
+                # background fetch refreshes the state does not advance the epoch a second time.
+                self.decentralized_state = DecentralizedState(max_epoch=self.epoch, total_steps=0)
+
+            if self.scheduler:
+                # advance the scheduler until its step counter reaches the current epoch
+                while self.epoch > self.scheduler._step_count:
+                    self.scheduler.step()
+
         with self.lock_parameters:
-            loss = self.opt.step(*args, **kwargs)
+            step_result = self.opt.step(*args, **kwargs)
         self.local_step += 1
+
         if self.local_step % self.averaging_step_period == 0:
             self.update_event.set()
-        return loss
+        # wake up the background threads that report local progress and fetch the progress of peers
+        self.report_progress_event.set()
+        self.fetch_decentralized_state_event.set()
+
+        return step_result
 
     def zero_grad(self, *args, **kwargs):
         return self.opt.zero_grad(*args, **kwargs)
@@ -75,6 +129,14 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.update_event.set()
         self.averager.shutdown()
 
+    def load_states_from_peers(self, **kwargs):
+        """
+        Reset gradients, load the averager state from peers and restart local progress tracking.
+
+        :param kwargs: forwarded to the averager's load_state_from_peers
+        """
+        logger.info("Trying to restore state from peers.")
+        # both locks are held (parameters first, then scheduler params) for the whole reload
+        with self.lock_parameters:
+            with self.lock_scheduler_params:
+                self.zero_grad()
+                self.averager.load_state_from_peers(**kwargs)
+                # start counting from scratch at the latest epoch known from peers
+                self.local_step, self.epoch = 0, self.decentralized_state.max_epoch
+
     @staticmethod
     @torch.no_grad()
     def _average_parameters_in_background(
@@ -105,6 +167,58 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             except Exception as e:
                 logger.error(f"Averaging round failed: caught {e}.")
 
+    @property
+    def is_synchronized(self) -> bool:
+        """Whether the local epoch trails the latest known epoch by at most max_allowed_epoch_difference."""
+        epochs_behind = self.decentralized_state.max_epoch - self.epoch
+        return epochs_behind <= self.max_allowed_epoch_difference
+
+    @torch.no_grad()
+    def _report_progress(self):
+        """Background loop: after each wake-up, publish [local_step, dht_time, epoch] of this peer to the DHT."""
+        stop_requested, wake_up = self.stop_event, self.report_progress_event
+        while not stop_requested.is_set():
+            # sleep until the next report is requested
+            wake_up.wait()
+            wake_up.clear()
+            if stop_requested.is_set():
+                return
+            now = get_dht_time()
+            # read both counters under the lock so that they belong to the same step
+            with self.lock_scheduler_params:
+                progress_record = [self.local_step, now, self.epoch]
+            expires_at = now + self.report_progress_expiration
+            self.dht.store(key=self.report_progress_key, subkey=self.averager.endpoint, value=progress_record,
+                           expiration_time=expires_at, return_future=False)
+
+    @torch.no_grad()
+    def _fetch_decentralized_state(self, initial: bool = False):
+        """
+        Read collaboration state reported by peers and store it in self.decentralized_state.
+
+        Each iteration waits for fetch_decentralized_state_event, reads report_progress_key from the DHT and
+        rebuilds the snapshot: max_epoch is the highest epoch among this peer and all reported peers,
+        total_steps is the sum of steps reported by the peers that are at that epoch. If nobody reported
+        anything, the snapshot is built from the local epoch and local step instead.
+
+        :param initial: if True, return after the first refresh instead of looping until stop_event is set
+        """
+        while not self.stop_event.is_set():
+            self.fetch_decentralized_state_event.wait()
+            self.fetch_decentralized_state_event.clear()
+            if self.stop_event.is_set():
+                break
+            response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
+            if not isinstance(response, dict) or not response:
+                logger.info(f"Found no active peers: {response}")
+                with self.lock_scheduler_params:
+                    self.decentralized_state = DecentralizedState(max_epoch=self.epoch, total_steps=self.local_step)
+                if initial:
+                    break
+                continue
+
+            # Keep only well-formed [local_step, dht_time, epoch] records: a single malformed entry
+            # must not raise here, since that would terminate this loop for good.
+            peer_progress = []
+            for peer_state in response.values():
+                if not isinstance(peer_state, ValueWithExpiration):
+                    continue
+                record = peer_state.value
+                if isinstance(record, (list, tuple)) and len(record) == 3:
+                    peer_progress.append(record)
+                else:
+                    logger.warning(f"Skipping malformed progress record: {record}")
+
+            with self.lock_scheduler_params:
+                global_epoch = max([self.epoch] + [epoch for _step, _reported_at, epoch in peer_progress])
+                # only the peers that are already at the latest epoch contribute to its step count
+                total_steps = sum(step for step, _reported_at, epoch in peer_progress if epoch == global_epoch)
+                self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
+
+            if initial:
+                break
+
 
 class DecentralizedSGD(DecentralizedOptimizer):
     """