Browse Source

change DecentralizedState to TrainingState

xtinkt 4 years ago
parent
commit
017fbe43a2
1 changed file with 20 additions and 20 deletions
  1. 20 20
      hivemind/optim/averaged.py

+ 20 - 20
hivemind/optim/averaged.py

@@ -15,7 +15,7 @@ logger = get_logger(__name__)
 
 
 @dataclass(frozen=False)
-class DecentralizedState:
+class TrainingState:
     max_epoch: int = 0
     total_steps: int = 0
 
@@ -72,10 +72,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
         self.total_steps_in_epoch = total_steps_in_epoch
         self.report_progress_key = f"{prefix}.progress"
-        self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
+        self.report_progress_event, self.fetch_training_state_event = Event(), Event()
         self.lock_scheduler_params = Lock()
-        self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)
-        self._fetch_decentralized_state()
+        self.training_state = TrainingState(max_epoch=0, total_steps=0)
+        self._fetch_training_state()
 
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -84,9 +84,9 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_averaging_thread.start()
         self.background_report_progress = Thread(name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
         self.background_report_progress.start()
-        self.background_fetch_decentralized_state = Thread(
-            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state_periodically)
-        self.background_fetch_decentralized_state.start()
+        self.background_fetch_training_state = Thread(
+            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_training_state_periodically)
+        self.background_fetch_training_state.start()
 
     def step(self, *args, **kwargs):
         if not self.is_synchronized:
@@ -95,11 +95,11 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             return
 
         with self.lock_scheduler_params:
-            if self.local_epoch < self.decentralized_state.max_epoch:
+            if self.local_epoch < self.training_state.max_epoch:
                 self.local_step = 0
-                self.local_epoch = self.decentralized_state.max_epoch
+                self.local_epoch = self.training_state.max_epoch
 
-            if self.decentralized_state.total_steps >= self.total_steps_in_epoch:
+            if self.training_state.total_steps >= self.total_steps_in_epoch:
                 self.local_epoch += 1
                 self.local_step = 0
 
@@ -114,7 +114,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         if self.local_step % self.averaging_step_period == 0:
             self.update_event.set()
         self.report_progress_event.set()
-        self.fetch_decentralized_state_event.set()
+        self.fetch_training_state_event.set()
 
         return step_result
 
@@ -136,7 +136,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             self.zero_grad()
             self.averager.load_state_from_peers(**kwargs)
             self.local_step = 0
-            self.local_epoch = self.decentralized_state.max_epoch
+            self.local_epoch = self.training_state.max_epoch
 
     @staticmethod
     @torch.no_grad()
@@ -169,7 +169,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 logger.error(f"Averaging round failed: caught {e}.")
 
     def is_synchronized(self) -> bool:
-        return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
+        return self.local_epoch + self.max_allowed_epoch_difference >= self.training_state.max_epoch
 
     @torch.no_grad()
     def _report_progress(self):
@@ -185,23 +185,23 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                            expiration_time=current_time + self.report_progress_expiration, return_future=False)
 
     @torch.no_grad()
-    def _fetch_decentralized_state_periodically(self):
+    def _fetch_training_state_periodically(self):
         """ Read decentralized state loop """
         while not self.stop_event.is_set():
-            self.fetch_decentralized_state_event.wait()
-            self.fetch_decentralized_state_event.clear()
+            self.fetch_training_state_event.wait()
+            self.fetch_training_state_event.clear()
             if self.stop_event.is_set():
                 break
-            self._fetch_decentralized_state()
+            self._fetch_training_state()
 
     @torch.no_grad()
-    def _fetch_decentralized_state(self):
+    def _fetch_training_state(self):
         """ Read decentralized state reported by peers """
         response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
         if not isinstance(response, dict) or len(response) == 0:
             logger.info(f"Found no active peers: {response}")
             with self.lock_scheduler_params:
-                self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
+                self.training_state = TrainingState(max_epoch=self.local_epoch, total_steps=self.local_step)
                 return
 
         valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
@@ -214,7 +214,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             for step, time, epoch in valid_peer_states:
                 if epoch == global_epoch:
                     total_steps += step
-            self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
+            self.training_state = TrainingState(max_epoch=global_epoch, total_steps=total_steps)
 
 
 class DecentralizedSGD(DecentralizedOptimizer):