
handle cancellation before trigger

justheuristic, 3 years ago
parent commit d655b746dc

+ 20 - 5
hivemind/averaging/averager.py

@@ -35,6 +35,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
+    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -413,11 +414,24 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             step.attach(trigger, cancel)
             future_for_init.set_result((trigger, cancel))
 
+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                try:
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+                    return group_info
+                except asyncio.CancelledError:
+                    await asyncio.gather(
+                        *(self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED) for peer_id in group_info.peer_ids)
+                    )
+                    raise
+
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
 
                     await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -428,13 +442,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         check_cancel_task.cancel()
 
                     group_info = await matchmaking_task
+
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
-
                     step.stage = AveragingStage.RUNNING_ALLREDUCE
 
                     step.set_result(
@@ -478,6 +489,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
+    async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
+        error = averaging_pb2.AveragingData(group_id=group_id, code=code)
+        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
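Note on the averager.py change: the AWAITING_TRIGGER wait now lives inside the matchmaking task, so cancelling a step after a group was found but before the trigger fired notifies the found peers (via the new `_send_error_to_peer` helper with `averaging_pb2.CANCELLED`) instead of leaving them waiting. A minimal, self-contained asyncio sketch of the same race-and-notify pattern; `find_group`, `notify_peer`, and the string codes are illustrative placeholders, not hivemind APIs:

    import asyncio

    async def find_group():
        await asyncio.sleep(0.1)              # stand-in for Matchmaking.look_for_group
        return ["peerA", "peerB"]             # stand-in for GroupInfo.peer_ids

    async def notify_peer(peer, code):
        print(f"telling {peer}: {code}")      # stand-in for sending the CANCELLED message

    async def find_peers_or_notify_cancel(trigger: asyncio.Event):
        peers = await find_group()
        try:
            await trigger.wait()              # the AWAITING_TRIGGER wait happens inside the task
            return peers
        except asyncio.CancelledError:
            # cancelled after matchmaking succeeded but before the trigger fired:
            # tell every prospective peer we are leaving, then propagate the cancellation
            await asyncio.gather(*(notify_peer(p, "CANCELLED") for p in peers))
            raise

    async def main():
        trigger = asyncio.Event()
        task = asyncio.create_task(find_peers_or_notify_cancel(trigger))
        await asyncio.sleep(0.2)              # the caller decides not to participate
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            print("step cancelled, peers were notified")

    asyncio.run(main())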

+ 25 - 6
hivemind/optim/experimental/optimizer.py

@@ -155,6 +155,7 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        preschedule_state_averaging: bool = False,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
@@ -189,6 +190,8 @@ class Optimizer(torch.optim.Optimizer):
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+        self.preschedule_state_averaging = preschedule_state_averaging
+
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
 
@@ -347,7 +350,8 @@ class Optimizer(torch.optim.Optimizer):
                     return loss  # local gradients were reset due to overflow, must start over
 
             self._maybe_schedule_gradient_averaging()
-            self._maybe_schedule_state_averaging()
+            if self.preschedule_state_averaging:
+                self._maybe_schedule_state_averaging()
 
         else:
             # use_local_updates=True: update parameters on every step independently of other peers
@@ -358,7 +362,8 @@ class Optimizer(torch.optim.Optimizer):
 
                 new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                 self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
-                self._maybe_schedule_state_averaging()
+                if self.preschedule_state_averaging:
+                    self._maybe_schedule_state_averaging()
 
                 self.state_averager.step(
                     increment_epoch=False,
@@ -399,8 +404,11 @@ class Optimizer(torch.optim.Optimizer):
 
             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():
-                    logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
-                                                     f"was already used elsewhere: {self.scheduled_state}")
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it "
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
                     self.scheduled_state = None
 
                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
@@ -417,6 +425,10 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
@@ -439,8 +451,11 @@ class Optimizer(torch.optim.Optimizer):
 
         began_averaging_gradients = False
         if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
-            logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
-                                             f"was already used elsewhere: {self.scheduled_state}")
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for gradient averaging because it "
+                f"was already used elsewhere: {self.scheduled_grads}",
+            )
             self.scheduled_grads = None
 
         elif self.tracker.global_progress.num_peers > 1:
@@ -487,6 +502,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        assert self.preschedule_state_averaging
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
@@ -582,6 +598,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
+        self._finish_background_averaging()
 
         with self.tracker.pause_updates():
             while True:
@@ -611,6 +628,8 @@ class Optimizer(torch.optim.Optimizer):
     def _finish_background_averaging(self):
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
             if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
                 if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
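Note on the optimizer.py change: pre-scheduling of state averaging is now opt-in via the new `preschedule_state_averaging` flag (default False); a scheduled state-averaging round is cancelled and cleared when the epoch advances without averaging state, and `load_state_from_peers` / `_finish_background_averaging` now cancel rounds that are still looking for a group. A hedged usage sketch of the new flag; apart from `preschedule_state_averaging` (taken from this diff), the import path and constructor arguments are written from memory of the hivemind.Optimizer API and may differ between versions:

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 1)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",                       # illustrative run name
        params=model.parameters(),
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        batch_size_per_step=32,
        target_batch_size=1024,
        preschedule_state_averaging=True,        # opt back in to pre-scheduling state averaging groups
        verbose=True,
    )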

+ 1 - 1
hivemind/optim/experimental/progress_tracker.py

@@ -83,7 +83,7 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
+        max_refresh_period: float = 10,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,

+ 1 - 3
hivemind/optim/experimental/state_averager.py

@@ -559,9 +559,7 @@ class TrainingStateAverager(DecentralizedAverager):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
         offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-        assert len(offloaded_parameters) == len(
-            self.main_parameters
-        ), "Optimizer parameters changed during training"
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
         for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
             main_param.copy_(offloaded_param, non_blocking=True)