
handle cancellation before trigger

justheuristic 3 years ago
parent
commit
d655b746dc

+ 20 - 5
hivemind/averaging/averager.py

@@ -35,6 +35,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
+    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -413,11 +414,24 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             step.attach(trigger, cancel)
             future_for_init.set_result((trigger, cancel))

+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                try:
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+                    return group_info
+                except asyncio.CancelledError:
+                    return asyncio.wait(
+                        self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
+                        for peer_id in group_info.peer_ids
+                    )
+
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())

                     await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -428,13 +442,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         check_cancel_task.cancel()

                     group_info = await matchmaking_task
+
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")

-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
-
                     step.stage = AveragingStage.RUNNING_ALLREDUCE

                     step.set_result(
@@ -478,6 +489,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )

+    async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
+        error = averaging_pb2.AveragingData(group_id=group_id, code=code)
+        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
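
For context, the new find_peers_or_notify_cancel coroutine folds the wait-for-trigger phase into the matchmaking task, so a step that is cancelled after a group was assembled but before the caller triggered all-reduce can still tell its peers (via the new _send_error_to_peer helper) that it is leaving. The snippet below is a minimal, self-contained sketch of that cancellation pattern; look_for_group, notify_cancel and the explicit trigger event are illustrative stand-ins, not hivemind's actual API.

import asyncio

async def look_for_group():
    # stand-in for matchmaking: pretend two peers are found after a short delay
    await asyncio.sleep(0.1)
    return ["peer-a", "peer-b"]

async def notify_cancel(peer_id):
    # stand-in for sending an averaging_pb2.CANCELLED message to one peer
    print(f"notified {peer_id} about cancellation")

async def find_peers_or_notify_cancel(trigger):
    peers = await look_for_group()
    try:
        await trigger.wait()  # the caller may cancel us while we wait for the trigger
        return peers
    except asyncio.CancelledError:
        # cancelled after a group was found but before all-reduce was triggered:
        # notify the peers we are leaving, then let the cancellation propagate
        await asyncio.gather(*(notify_cancel(peer) for peer in peers))
        raise

async def main():
    trigger = asyncio.Event()
    step = asyncio.create_task(find_peers_or_notify_cancel(trigger))
    await asyncio.sleep(0.2)  # matchmaking succeeds, but the trigger never fires
    step.cancel()             # the user cancels the averaging step
    try:
        await step
    except asyncio.CancelledError:
        print("step cancelled cleanly")

asyncio.run(main())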

+ 25 - 6
hivemind/optim/experimental/optimizer.py

@@ -155,6 +155,7 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        preschedule_state_averaging: bool = False,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
@@ -189,6 +190,8 @@ class Optimizer(torch.optim.Optimizer):
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+        self.preschedule_state_averaging = preschedule_state_averaging
+
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout

@@ -347,7 +350,8 @@ class Optimizer(torch.optim.Optimizer):
                     return loss  # local gradients were reset due to overflow, must start over

             self._maybe_schedule_gradient_averaging()
-            self._maybe_schedule_state_averaging()
+            if self.preschedule_state_averaging:
+                self._maybe_schedule_state_averaging()

         else:
             # use_local_updates=True: update parameters on every step independently of other peers
@@ -358,7 +362,8 @@ class Optimizer(torch.optim.Optimizer):

                 new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                 self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
-                self._maybe_schedule_state_averaging()
+                if self.preschedule_state_averaging:
+                    self._maybe_schedule_state_averaging()

                 self.state_averager.step(
                     increment_epoch=False,
@@ -399,8 +404,11 @@ class Optimizer(torch.optim.Optimizer):

             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():
-                    logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
-                                                     f"was already used elsewhere: {self.scheduled_state}")
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it"
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
                     self.scheduled_state = None

                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
@@ -417,6 +425,10 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )

+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
@@ -439,8 +451,11 @@ class Optimizer(torch.optim.Optimizer):

         began_averaging_gradients = False
         if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
-            logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
-                                             f"was already used elsewhere: {self.scheduled_state}")
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for state averaging because it"
+                f"was already used elsewhere: {self.scheduled_state}",
+            )
             self.scheduled_grads = None

         elif self.tracker.global_progress.num_peers > 1:
@@ -487,6 +502,7 @@ class Optimizer(torch.optim.Optimizer):

     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        assert self.preschedule_state_averaging
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
@@ -582,6 +598,7 @@ class Optimizer(torch.optim.Optimizer):

     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
+        self._finish_background_averaging()

         with self.tracker.pause_updates():
             while True:
@@ -611,6 +628,8 @@ class Optimizer(torch.optim.Optimizer):
     def _finish_background_averaging(self):
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
             if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
                 if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
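
The preschedule_state_averaging flag added above defaults to False, so _maybe_schedule_state_averaging is no longer called ahead of the epoch boundary unless explicitly requested. A hypothetical way to opt back in when constructing the optimizer is sketched below; apart from preschedule_state_averaging itself, the surrounding argument names are assumptions about the hivemind.Optimizer signature of this period and may not match exactly.

import torch
import hivemind

# assumed setup: every name except preschedule_state_averaging is a guess at the
# hivemind.Optimizer constructor arguments around this commit
dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 1)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=1024,
    batch_size_per_step=32,
    preschedule_state_averaging=True,  # re-enable pre-scheduled state averaging (new in this commit)
    verbose=True,
)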

+ 1 - 1
hivemind/optim/experimental/progress_tracker.py

@@ -83,7 +83,7 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
+        max_refresh_period: float = 10,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
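
The only change here lowers the default max_refresh_period from 30 to 10 seconds. Assuming the tracker adapts how often it republishes local progress and clamps that interval between min_refresh_period and max_refresh_period (a reading of the parameter names, not of the implementation), the tighter cap bounds how stale a peer's reported progress can get:

def next_refresh_period(estimated_time_to_next_epoch: float,
                        min_refresh_period: float = 0.5,
                        max_refresh_period: float = 10.0) -> float:
    # illustrative clamp only, not ProgressTracker's actual logic
    return max(min_refresh_period, min(estimated_time_to_next_epoch, max_refresh_period))

# with the previous default of 30, a slow run could go half a minute between refreshes;
# the new cap forces one at least every 10 seconds
print(next_refresh_period(45.0))  # -> 10.0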

+ 1 - 3
hivemind/optim/experimental/state_averager.py

@@ -559,9 +559,7 @@ class TrainingStateAverager(DecentralizedAverager):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
         offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-        assert len(offloaded_parameters) == len(
-            self.main_parameters
-        ), "Optimizer parameters changed during training"
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
         for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
             main_param.copy_(offloaded_param, non_blocking=True)
             main_param.copy_(offloaded_param, non_blocking=True)