
hopefully final update

justheuristic 3 years ago
parent
commit
9bdf6c71b0
1 changed file with 11 additions and 10 deletions
hivemind/optim/experimental/optimizer.py  +11 -10

@@ -137,7 +137,7 @@ class Optimizer(torch.optim.Optimizer):
         params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
-        averaging_timeout: Optional[float] = 300.0,
+        averaging_timeout: Optional[float] = 60.0,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -397,7 +397,12 @@ class Optimizer(torch.optim.Optimizer):
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
             should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
 
-            if should_average_state:
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it "
+                                                     f"was already used elsewhere: {self.scheduled_state}")
+                    self.scheduled_state = None
+
                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
 
             self.state_averager.step(
@@ -412,10 +417,6 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
-            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
-                self.scheduled_state.cancel()
-                self.scheduled_state = None
-
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
@@ -437,7 +438,8 @@ class Optimizer(torch.optim.Optimizer):
                 assert grad_scaler.unscale_(self)
 
         if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
-            logger.debug(self.status_loglevel, f"Discarding previous matchmaking results: {self.scheduled_grads}")
+            logger.log(self.status_loglevel, f"Not using pre-scheduled group for gradient averaging because it "
+                                             f"was already used elsewhere: {self.scheduled_grads}")
             self.scheduled_grads = None
 
         began_averaging_gradients = False
@@ -614,17 +616,16 @@ class Optimizer(torch.optim.Optimizer):
                 if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
             if scheduled_round is not None and not scheduled_round.done():
                 try:
                     time_to_deadline = scheduled_round.deadline - get_dht_time()
-                    scheduled_round.result(timeout=max(0.0, min(time_to_deadline, self.shutdown_timeout)))
+                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
                 if not scheduled_round.done():
                     scheduled_round.cancel()
+        self.scheduled_grads = self.scheduled_state = None
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()