
[benchmarking finished] avoid accidentally cancelling averaging rounds

justheuristic, 3 years ago
commit b2be06c36e

+ 4 - 2
hivemind/averaging/control.py

@@ -147,8 +147,10 @@ class StepControl(MPFuture):
 
     def __del__(self):
         if os.getpid() == self._origin_pid and not self.triggered:
-            logger.warning("Deleted an averaging StepControl, but the step was not triggered. This may cause other "
-                           "peers to fail an averaging round via TimeoutError.")
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
         super().__del__()
 
     def cancel(self) -> bool:
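
The warning above fires from __del__ when a scheduled StepControl is garbage-collected without ever being triggered: peers that already formed a group around it would stall until TimeoutError. A minimal usage sketch, not part of this commit, where `averager` stands for any averager exposing the schedule_step/StepControl API seen in this diff:

# Hedged sketch: always trigger or explicitly cancel a scheduled step;
# never let the handle be garbage-collected while untriggered.
control = averager.schedule_step(timeout=30.0)  # returns a StepControl
try:
    control.allow_allreduce()      # trigger the round so peers can proceed
    control.result(timeout=30.0)   # wait for the all-reduce to complete
except BaseException:
    if not control.done():
        control.cancel()           # cancel explicitly instead of dropping it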

+ 23 - 20
hivemind/optim/experimental/optimizer.py

@@ -8,7 +8,7 @@ from typing import Callable, Optional, Sequence, Union
 
 import torch
 
-from hivemind.averaging.control import StepControl, AveragingStage
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
@@ -234,6 +234,7 @@ class Optimizer(torch.optim.Optimizer):
         return TrainingStateAverager(
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
@@ -251,6 +252,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
@@ -409,6 +411,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_control=self.scheduled_state if should_average_state else None,
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
+
             if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
                 self.scheduled_state.cancel()
                 self.scheduled_state = None
@@ -443,7 +446,6 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_grads = self.grad_averager.step(
                     control=self.scheduled_grads, reset_accumulators=True, wait=False
                 )
-                assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
                 began_averaging_gradients = True
             except BaseException as e:
                 logger.exception(e)
@@ -479,10 +481,7 @@ class Optimizer(torch.optim.Optimizer):
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                 eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
-                scheduled_time = self.tracker.estimated_next_update_time
-                if self.client_mode:
-                    scheduled_time = get_dht_time() + self.averaging_timeout
-                self.scheduled_grads = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
@@ -496,14 +495,16 @@ class Optimizer(torch.optim.Optimizer):
         eta_seconds_to_averaging = estimated_time - get_dht_time()
 
         if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.delay_state_averaging:
+                # wait for previous averaging to finish before starting a new one
+                self.state_averager.step(wait_for_delayed_updates=True)
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+
                 min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
                 actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
                 logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
-                if self.client_mode:
-                    estimated_time = get_dht_time() + self.averaging_timeout
                 self.scheduled_state = self.state_averager.schedule_step(
-                    estimated_time, gather=next_epoch, timeout=self.averaging_timeout
+                    gather=next_epoch, timeout=self.averaging_timeout
                 )
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
@@ -582,7 +583,8 @@
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        self._finish_scheduled_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)
+        self._finish_background_averaging()
 
         with self.tracker.pause_updates():
             while True:
@@ -609,22 +611,23 @@ class Optimizer(torch.optim.Optimizer):
                 if not self.client_mode:
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
-    def _finish_scheduled_averaging(self):
+    def _finish_background_averaging(self):
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
             if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if scheduled_round.stage == AveragingStage.AWAITING_TRIGGER:
+                if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
+            if scheduled_round is not None and not scheduled_round.done():
                 try:
-                    scheduled_round.result(timeout=max(0.0, scheduled_round.deadline - get_dht_time()))
+                    time_to_deadline = scheduled_round.deadline - get_dht_time()
+                    scheduled_round.result(timeout=max(0.0, min(time_to_deadline, self.shutdown_timeout)))
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
-            if not scheduled_round.done():
-                scheduled_round.cancel()
+                if not scheduled_round.done():
+                    scheduled_round.cancel()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
@@ -667,10 +670,10 @@
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
-        self._finish_scheduled_averaging()
         self.tracker.shutdown(self.shutdown_timeout)
-        logger.debug("Shutting down averagers...")
         self.state_averager.step(wait_for_delayed_updates=True)
+        self._finish_background_averaging()
+        logger.debug("Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
             self.grad_averager.shutdown()
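
Taken together, the optimizer changes implement the fix named in the commit title: a round whose group has already formed is no longer cancelled outright (which made the remaining peers fail via TimeoutError); instead the departing peer joins with weight 0, and only rounds still in matchmaking are cancelled. A condensed, hedged sketch of that policy (the helper name and `rounds` argument are illustrative, not the actual signature; AveragingStage, get_dht_time, and logger are assumed from the module's imports):

def finish_background_averaging(rounds, shutdown_timeout):
    # rounds: the optimizer's scheduled StepControl handles (grads, state)
    for ctl in rounds:
        if ctl is None:
            continue
        if not ctl.triggered:
            ctl.weight = 0         # contribute nothing, but let the group's
            ctl.allow_allreduce()  # all-reduce complete for the other peers
        if ctl.stage == AveragingStage.LOOKING_FOR_GROUP:
            ctl.cancel()           # safe: no peers are committed to us yet
    for ctl in rounds:
        if ctl is not None and not ctl.done():
            try:
                # wait no longer than the round's deadline or the shutdown budget
                ctl.result(timeout=max(0.0, min(ctl.deadline - get_dht_time(), shutdown_timeout)))
            except BaseException as e:
                logger.warning(f"Caught {e} while finishing a background averaging round")
            if not ctl.done():
                ctl.cancel()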

+ 0 - 1
hivemind/optim/experimental/state_averager.py

@@ -476,7 +476,6 @@ class TrainingStateAverager(DecentralizedAverager):
         try:
             if averaging_round and averaging_control is None:
                 averaging_control = super().step(
-                    scheduled_time=get_dht_time() + self.delay_before_averaging.ema_seconds_per_sample,
                     gather=self.local_epoch,
                     require_trigger=True,
                     timeout=timeout,
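
Finally, state_averager.py stops passing an explicit scheduled_time to super().step(), mirroring the optimizer-side change above: callers now let the averager choose its own start time instead of guessing one (the old client-mode guess of get_dht_time() + averaging_timeout could mis-schedule a round and get it cancelled). A hedged before/after sketch of the call site, with variable names taken from the diff:

# Before: the caller supplied a guessed start time.
# control = state_averager.schedule_step(estimated_time, gather=next_epoch, timeout=averaging_timeout)

# After: omit scheduled_time and let the averager pick its default.
control = state_averager.schedule_step(gather=next_epoch, timeout=averaging_timeout)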