Jelajahi Sumber

Merge branch 'check' into Checkpointer2

Michael Diskin 3 tahun lalu
induk
melakukan
6f89bcf439

+ 1 - 1
examples/albert/arguments.py

@@ -49,7 +49,7 @@ class AveragerArguments:
         default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}

+ 1 - 1
hivemind/averaging/partition.py

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 19
+DEFAULT_PART_SIZE_BYTES = 2 ** 20
 
 
 class TensorPartContainer:

+ 39 - 19
hivemind/optim/collaborative.py

@@ -149,7 +149,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -181,6 +181,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    @property
+    def is_within_tolerance(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
@@ -197,7 +201,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
 
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
@@ -226,10 +230,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
+        if not self.is_synchronized and not self.is_within_tolerance:
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -241,7 +248,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
+            self.local_updates_accumulated += 1
             self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
@@ -249,25 +256,31 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
                 try:
-                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    group_info = self.averager.step(
+                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
+                    )
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -279,7 +292,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
@@ -304,12 +317,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -412,9 +432,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
-            )
+            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
+            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
+
             return CollaborationState(
                 self.local_step,
                 self.local_samples_accumulated,