Fix step_tolerance in CollaborativeOptimizer (#383)

If a peer detects another peer that is ahead of it but still within tolerance, it will jump to that peer's step.

This fixes a rare bug: when one peer inadvertently ended up one step ahead of the others, it would cause the following cycle:

1. old peers would not load state from the new one because the new peer was still within step_tolerance
2. old peers would not contribute to training minibatches because we did not account for step tolerance there
3. new peers would maintain the 1-step distance from the others because they made steps concurrently with the rest

In practice, this occasionally caused a noticeable performance degradation for several (~10) steps during training, and in the worst case it could lead to a significant loss of time and an excessive accumulated batch size.
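The gist of the fix, as a minimal runnable sketch (the function name and string labels here are ours for illustration; the actual logic lives in CollaborativeOptimizer.step, shown in the diff below):

# Minimal sketch of the fixed decision logic, simplified from the diff below.
def sync_action(local_step: int, global_step: int, step_tolerance: int) -> str:
    is_synchronized = local_step >= global_step
    is_within_tolerance = local_step >= global_step - step_tolerance
    if not is_synchronized and not is_within_tolerance:
        return "load_state_from_peers"  # fully out of sync: download state from peers
    if not is_synchronized and is_within_tolerance:
        return "jump_to_collaboration_step"  # new behavior: catch up in place
    return "proceed"  # already synchronized

assert sync_action(local_step=41, global_step=42, step_tolerance=1) == "jump_to_collaboration_step"
assert sync_action(local_step=40, global_step=42, step_tolerance=1) == "load_state_from_peers"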

Other changes:
* changed the default averaging timeout from 30 seconds to 1 minute to ensure that users can run something even over a slow connection
* fixed a minor bug that produced a wrong ETA for the next step when the collaboration consisted of exactly 1 peer (illustrated below)
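The ETA fix amounts to dividing the remaining samples, rather than the remaining optimizer calls, by the measured throughput. A quick illustration with made-up numbers (see the last hunk of collaborative.py below):

# Illustrative numbers only: a lone peer that has accumulated 3072 of 4096
# target samples while processing 256 samples per second.
target_batch_size = 4096
local_samples_accumulated = 3072
samples_per_second = 256.0

samples_left = max(0, target_batch_size - local_samples_accumulated)  # 1024
local_eta_next_step = samples_left / samples_per_second  # 4.0 seconds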

Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 years ago
parent
commit
504a96363e
2 changed files with 40 additions and 20 deletions
  1. examples/albert/arguments.py (+1 -1)
  2. hivemind/optim/collaborative.py (+39 -19)

examples/albert/arguments.py (+1 -1)

@@ -49,7 +49,7 @@ class AveragerArguments:
         default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}

hivemind/optim/collaborative.py (+39 -19)

@@ -149,7 +149,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -181,6 +181,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    @property
+    def is_within_tolerance(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
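After this hunk, is_synchronized means the peer is fully caught up with the collaboration, while is_within_tolerance is the weaker condition that it lags by at most step_tolerance steps. For example, with step_tolerance = 1, a peer at local step 41 while the collaboration is at step 42 is not synchronized but is within tolerance, so step() below jumps it forward instead of reloading state.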
@@ -197,7 +201,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
 
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
@@ -215,10 +219,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
+        if not self.is_synchronized and not self.is_within_tolerance:
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -230,7 +237,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
+            self.local_updates_accumulated += 1
             self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
@@ -238,25 +245,31 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
                 try:
-                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    group_info = self.averager.step(
+                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
+                    )
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
 
@@ -268,7 +281,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
@@ -293,12 +306,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -401,9 +421,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
-            )
+            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
+            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
+
             return CollaborationState(
                 self.local_step,
                 self.local_samples_accumulated,