Explorar el Código

hopefully fix freezing, add averaging frequency

justheuristic hace 3 años
padre
commit
21a2a57edb
Se han modificado 1 ficheros con 5 adiciones y 4 borrados
  1. 5 4
      hivemind/optim/experimental/optimizer.py

+ 5 - 4
hivemind/optim/experimental/optimizer.py

@@ -256,8 +256,8 @@ class Optimizer(torch.optim.Optimizer):
                 logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
                 self.scheduled_round = None
 
-            need_averaging = self.tracker.global_progress.num_peers > 1
-            if need_averaging:
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            if swarm_not_empty:
                 try:
                     group_info = self.grad_averager.step(
                         control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
@@ -276,18 +276,19 @@ class Optimizer(torch.optim.Optimizer):
             assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
             with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
                 # note: we do not need to replace because the offloaded optimizer is already using averaged grads
+                next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
 
                 self.state_averager.step(
                     increment_epoch=True,
                     optimizer_step=True,
                     delay_optimizer_step=self.delay_optimizer_step,
                     grad_scaler=grad_scaler,
-                    averaging_round=need_averaging and self.tracker.global_epoch % self.average_state_every == 0,
+                    averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
                     delay_averaging=True,
                     averaging_opts=dict(
                         scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                     )
-                    if need_averaging
+                    if swarm_not_empty
                     else None,
                 )