@@ -149,7 +149,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -181,6 +181,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    @property
+    def is_within_tolerance(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
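
The hunk above splits the old check in two: `is_synchronized` now means the peer is exactly caught up with the collaboration, while `is_within_tolerance` accepts a lag of at most `step_tolerance` optimizer steps. A minimal sketch of the two predicates, using hypothetical step counters:

# Hypothetical values; assume step_tolerance = 1.
local_step, optimizer_step, step_tolerance = 10, 11, 1

is_synchronized = local_step >= optimizer_step                       # False: one step behind
is_within_tolerance = local_step >= optimizer_step - step_tolerance  # True: lag of 1 <= tolerance

# in sync          -> proceed normally
# within tolerance -> fast-forward the local step counter (see the step() hunk below)
# beyond tolerance -> load the full state from peers
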
@@ -197,7 +201,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                 continue
 
-        self.local_samples_accumulated = self.local_steps_accumulated = 0
+        self.local_samples_accumulated = self.local_updates_accumulated = 0
         self.reset_accumulated_grads_()
         self.update_scheduler()
 
@@ -226,10 +230,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
+        if not self.is_synchronized and not self.is_within_tolerance:
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -241,7 +248,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
+            self.local_updates_accumulated += 1
             self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
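
With gradient accumulation, step() is called once per microbatch, so the two counters above grow at different rates: local_samples_accumulated by batch_size per call, local_updates_accumulated by one. The rename clarifies that this counter tracks accumulation updates rather than optimizer steps. A worked example with hypothetical numbers:

# Hypothetical: batch_size = 32 and four calls to step() since the last global update.
batch_size = 32
local_samples_accumulated = 4 * batch_size  # 128 samples processed locally
local_updates_accumulated = 4               # 4 gradient accumulations

# Gradients were summed over 4 backward passes, so scaling by
# 1.0 / local_updates_accumulated (next hunk) recovers their average.
assert 1.0 / local_updates_accumulated == 0.25
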
@@ -249,25 +256,31 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
                 try:
-                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    group_info = self.averager.step(
+                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
+                    )
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
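
Two things change in the averaging call above. Each peer's contribution is weighted by the samples it actually processed relative to the fair share target_batch_size / num_peers, and the new gather=current_step argument makes the averager exchange every peer's step counter alongside the tensors, returning it in group_info. A sketch of both with hypothetical values (a real group_info comes from averager.step):

# Hypothetical: 8 peers aim for a 4096-sample global batch.
target_batch_size, num_peers = 4096, 8
mean_samples_per_worker = target_batch_size / num_peers       # 512.0
local_samples_accumulated = 256
weight = local_samples_accumulated / mean_samples_per_worker  # 0.5: half a fair share

# Hypothetical gathered metadata: peer endpoint -> that peer's local step.
group_info = {"peer_a": 10, "peer_b": 12, "peer_c": "garbage"}
current_step = 10
for peer, peer_step in group_info.items():
    if isinstance(peer_step, int):
        current_step = max(current_step, peer_step)  # adopt the most recent step seen
    # non-int metadata is skipped (the diff logs it as malformed)
assert current_step == 12
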
@@ -279,7 +292,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
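
Since current_step may have been raised to the most advanced gathered value, registering current_step + 1 lets a slightly stale peer rejoin at the collaboration's step instead of its own. Continuing the hypothetical numbers from the sketch above:

# Local peer was at step 10 but averaged with a peer at step 12,
# so after opt.step() it registers and resumes from step 13, not 11.
current_step = 12
assert current_step + 1 == 13
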
@@ -304,12 +317,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -412,9 +432,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
-            )
+            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
+            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
+
             return CollaborationState(
                 self.local_step,
                 self.local_samples_accumulated,
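
The last hunk fixes a units bug: the old ETA subtracted a count of step() calls from a target expressed in samples. The rewrite subtracts samples from samples before dividing by throughput. A quick check with hypothetical numbers:

# Hypothetical: 4096-sample target, 1024 samples accumulated, 512 samples/s throughput.
target_batch_size, local_samples_accumulated = 4096, 1024
samples_per_second = 512.0

samples_left_to_target_batch_size = max(0, target_batch_size - local_samples_accumulated)  # 3072
local_eta_next_step = samples_left_to_target_batch_size / samples_per_second               # 6.0 seconds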