@@ -15,7 +15,7 @@ logger = get_logger(__name__)
 
 
 @dataclass(frozen=False)
-class DecentralizedState:
+class TrainingState:
     max_epoch: int = 0
     total_steps: int = 0
 
 
@@ -72,10 +72,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.max_allowed_epoch_difference = max_allowed_epoch_difference
         self.total_steps_in_epoch = total_steps_in_epoch
         self.report_progress_key = f"{prefix}.progress"
-        self.report_progress_event, self.fetch_decentralized_state_event = Event(), Event()
+        self.report_progress_event, self.fetch_training_state_event = Event(), Event()
         self.lock_scheduler_params = Lock()
-        self.decentralized_state = DecentralizedState(max_epoch=0, total_steps=0)
-        self._fetch_decentralized_state()
+        self.training_state = TrainingState(max_epoch=0, total_steps=0)
+        self._fetch_training_state()
 
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -84,9 +84,9 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_averaging_thread.start()
         self.background_report_progress = Thread(name=f'{self.__class__.__name__}.reporter', daemon=True, target=self._report_progress)
         self.background_report_progress.start()
-        self.background_fetch_decentralized_state = Thread(
-            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_decentralized_state_periodically)
-        self.background_fetch_decentralized_state.start()
+        self.background_fetch_training_state = Thread(
+            name=f'{self.__class__.__name__}.state_updater', daemon=True, target=self._fetch_training_state_periodically)
+        self.background_fetch_training_state.start()
 
     def step(self, *args, **kwargs):
         if not self.is_synchronized:
@@ -95,11 +95,11 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             return
 
         with self.lock_scheduler_params:
-            if self.local_epoch < self.decentralized_state.max_epoch:
+            if self.local_epoch < self.training_state.max_epoch:
                 self.local_step = 0
-                self.local_epoch = self.decentralized_state.max_epoch
+                self.local_epoch = self.training_state.max_epoch
 
-            if self.decentralized_state.total_steps >= self.total_steps_in_epoch:
+            if self.training_state.total_steps >= self.total_steps_in_epoch:
                 self.local_epoch += 1
                 self.local_step = 0
 
@@ -114,7 +114,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         if self.local_step % self.averaging_step_period == 0:
             self.update_event.set()
         self.report_progress_event.set()
-        self.fetch_decentralized_state_event.set()
+        self.fetch_training_state_event.set()
 
         return step_result
 
@@ -136,7 +136,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.zero_grad()
         self.averager.load_state_from_peers(**kwargs)
         self.local_step = 0
-        self.local_epoch = self.decentralized_state.max_epoch
+        self.local_epoch = self.training_state.max_epoch
 
     @staticmethod
     @torch.no_grad()
@@ -169,7 +169,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 logger.error(f"Averaging round failed: caught {e}.")
 
     def is_synchronized(self) -> bool:
-        return self.local_epoch + self.max_allowed_epoch_difference >= self.decentralized_state.max_epoch
+        return self.local_epoch + self.max_allowed_epoch_difference >= self.training_state.max_epoch
 
     @torch.no_grad()
     def _report_progress(self):
@@ -185,23 +185,23 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                            expiration_time=current_time + self.report_progress_expiration, return_future=False)
 
     @torch.no_grad()
-    def _fetch_decentralized_state_periodically(self):
+    def _fetch_training_state_periodically(self):
         """ Read decentralized state loop """
         while not self.stop_event.is_set():
-            self.fetch_decentralized_state_event.wait()
-            self.fetch_decentralized_state_event.clear()
+            self.fetch_training_state_event.wait()
+            self.fetch_training_state_event.clear()
             if self.stop_event.is_set():
                 break
-            self._fetch_decentralized_state()
+            self._fetch_training_state()
 
     @torch.no_grad()
-    def _fetch_decentralized_state(self):
+    def _fetch_training_state(self):
         """ Read decentralized state reported by peers """
         response, _expiration = self.dht.get(self.report_progress_key, latest=True) or (None, -float('inf'))
         if not isinstance(response, dict) or len(response) == 0:
             logger.info(f"Found no active peers: {response}")
             with self.lock_scheduler_params:
-                self.decentralized_state = DecentralizedState(max_epoch=self.local_epoch, total_steps=self.local_step)
+                self.training_state = TrainingState(max_epoch=self.local_epoch, total_steps=self.local_step)
             return
 
         valid_peer_states = [peer_state.value for peer_state in response.values() if isinstance(peer_state, ValueWithExpiration)]
@@ -214,7 +214,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         for step, time, epoch in valid_peer_states:
             if epoch == global_epoch:
                 total_steps += step
-        self.decentralized_state = DecentralizedState(max_epoch=global_epoch, total_steps=total_steps)
+        self.training_state = TrainingState(max_epoch=global_epoch, total_steps=total_steps)
 
 
 class DecentralizedSGD(DecentralizedOptimizer):
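
For reference, the scheduler bookkeeping that _fetch_training_state performs over peer
reports can be reproduced outside the optimizer. The sketch below is a minimal,
self-contained rendition: the function name aggregate_peer_states and the sample peer
tuples are illustrative assumptions, but the fold mirrors the loop above. Each peer
reports a (step, time, epoch) tuple; the global epoch is the furthest epoch any peer
reached, and total_steps sums only the steps taken at that epoch.

from dataclasses import dataclass

@dataclass(frozen=False)
class TrainingState:
    max_epoch: int = 0
    total_steps: int = 0

def aggregate_peer_states(valid_peer_states):
    # Global epoch: the furthest epoch reported by any peer.
    global_epoch = max(epoch for _step, _time, epoch in valid_peer_states)
    # Count only the steps taken at that epoch, as in _fetch_training_state.
    total_steps = 0
    for step, _time, epoch in valid_peer_states:
        if epoch == global_epoch:
            total_steps += step
    return TrainingState(max_epoch=global_epoch, total_steps=total_steps)

# Hypothetical peer reports of (local_step, timestamp, local_epoch):
peers = [(120, 1690000000.0, 3), (80, 1690000001.0, 3), (500, 1689990000.0, 2)]
print(aggregate_peer_states(peers))  # TrainingState(max_epoch=3, total_steps=200)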
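
The two renamed scheduler updates in step() can likewise be read in isolation. Assuming
the TrainingState dataclass above, the hypothetical helper synchronize_epoch restates
the block guarded by lock_scheduler_params: a peer that has fallen behind jumps forward
to the swarm's max_epoch, and once the peers' combined steps fill an epoch, the epoch
rolls over.

def synchronize_epoch(local_step, local_epoch, training_state, total_steps_in_epoch):
    # (1) A peer behind the swarm resets its step and jumps to the reported epoch.
    if local_epoch < training_state.max_epoch:
        local_step = 0
        local_epoch = training_state.max_epoch
    # (2) Once the swarm's combined steps fill the epoch, advance to the next one.
    if training_state.total_steps >= total_steps_in_epoch:
        local_epoch += 1
        local_step = 0
    return local_step, local_epoch

# A peer at epoch 1 sees the swarm at epoch 3 with a completed epoch of 1000 steps:
print(synchronize_epoch(42, 1, TrainingState(max_epoch=3, total_steps=1000), 1000))
# -> (0, 4): jump to epoch 3, then roll over because the epoch is already full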