Explorar o código

[benchmarking finished] avoid accidentally cancelling averaging rounds

justheuristic %!s(int64=3) %!d(string=hai) anos
pai
achega
b2be06c36e

+ 4 - 2
hivemind/averaging/control.py

@@ -147,8 +147,10 @@ class StepControl(MPFuture):
 
     def __del__(self):
         if os.getpid() == self._origin_pid and not self.triggered:
-            logger.warning("Deleted an averaging StepControl, but the step was not triggered. This may cause other "
-                           "peers to fail an averaging round via TimeoutError.")
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
         super().__del__()
 
     def cancel(self) -> bool:

+ 23 - 20
hivemind/optim/experimental/optimizer.py

@@ -8,7 +8,7 @@ from typing import Callable, Optional, Sequence, Union
 
 import torch
 
-from hivemind.averaging.control import StepControl, AveragingStage
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
@@ -234,6 +234,7 @@ class Optimizer(torch.optim.Optimizer):
         return TrainingStateAverager(
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
@@ -251,6 +252,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
@@ -409,6 +411,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_control=self.scheduled_state if should_average_state else None,
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
+
             if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
                 self.scheduled_state.cancel()
                 self.scheduled_state = None
@@ -443,7 +446,6 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_grads = self.grad_averager.step(
                     control=self.scheduled_grads, reset_accumulators=True, wait=False
                 )
-                assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
                 began_averaging_gradients = True
             except BaseException as e:
                 logger.exception(e)
@@ -479,10 +481,7 @@ class Optimizer(torch.optim.Optimizer):
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                 eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
-                scheduled_time = self.tracker.estimated_next_update_time
-                if self.client_mode:
-                    scheduled_time = get_dht_time() + self.averaging_timeout
-                self.scheduled_grads = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
@@ -496,14 +495,16 @@ class Optimizer(torch.optim.Optimizer):
         eta_seconds_to_averaging = estimated_time - get_dht_time()
 
         if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.delay_state_averaging:
+                # wait for previous averaging to finish before starting a new one
+                self.state_averager.step(wait_for_delayed_updates=True)
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+
                 min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
                 actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
                 logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
-                if self.client_mode:
-                    estimated_time = get_dht_time() + self.averaging_timeout
                 self.scheduled_state = self.state_averager.schedule_step(
-                    estimated_time, gather=next_epoch, timeout=self.averaging_timeout
+                    gather=next_epoch, timeout=self.averaging_timeout
                 )
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
@@ -582,7 +583,8 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        self._finish_scheduled_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)
+        self._finish_background_averaging()
 
         with self.tracker.pause_updates():
             while True:
@@ -609,22 +611,23 @@ class Optimizer(torch.optim.Optimizer):
                 if not self.client_mode:
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
-    def _finish_scheduled_averaging(self):
+    def _finish_background_averaging(self):
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
             if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if scheduled_round.stage == AveragingStage.AWAITING_TRIGGER:
+                if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
+            if scheduled_round is not None and not scheduled_round.done():
                 try:
-                    scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
+                    time_to_deadline = scheduled_round.deadline - get_dht_time()
+                    scheduled_round.result(timeout=max(0.0, min(time_to_deadline, self.shutdown_timeout)))
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
-            if not scheduled_round.done():
-                scheduled_round.cancel()
+                if not scheduled_round.done():
+                    scheduled_round.cancel()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
@@ -667,10 +670,10 @@ class Optimizer(torch.optim.Optimizer):
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
-        self._finish_scheduled_averaging()
         self.tracker.shutdown(self.shutdown_timeout)
-        logger.debug("Shutting down averagers...")
         self.state_averager.step(wait_for_delayed_updates=True)
+        self._finish_background_averaging()
+        logger.debug("Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
             self.grad_averager.shutdown()

+ 0 - 1
hivemind/optim/experimental/state_averager.py

@@ -476,7 +476,6 @@ class TrainingStateAverager(DecentralizedAverager):
         try:
             if averaging_round and averaging_control is None:
                 averaging_control = super().step(
-                    scheduled_time=get_dht_time() + self.delay_before_averaging.ema_seconds_per_sample,
                     gather=self.local_epoch,
                     require_trigger=True,
                     timeout=timeout,