Browse Source

debug delay_grad_averaging into submission

justheuristic 3 years ago
parent
commit
6ead2dce3d

+ 1 - 1
benchmarks/benchmark_optimizer.py

@@ -25,7 +25,7 @@ class TrainingArguments:
 
     num_peers: int = 8
     num_clients: int = 3
-    target_batch_size: int = 128
+    target_batch_size: int = 256
     reuse_grad_buffers: bool = True
     delay_grad_averaging: bool = True
     delay_optimizer_step: bool = True

+ 10 - 8
hivemind/optim/experimental/grad_averager.py

@@ -207,17 +207,19 @@ class GradientAverager(DecentralizedAverager):
 
     @contextlib.contextmanager
     @torch.no_grad()
-    def use_averaged_gradients(self, replace_model_gradients: bool = True):
+    def use_averaged_gradients(self):
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
             assert len(averaged_grads) == len(self.parameters)
             try:
-                if replace_model_gradients:
-                    old_grads = [param.grad for param in self.parameters]
-                    for param, new_grad in zip(self.parameters, averaged_grads):
-                        param.grad = new_grad
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
+                    param.grad = new_grad
                 yield averaged_grads
             finally:
-                if replace_model_gradients:
-                    for param, old_grad in zip(self.parameters, old_grads):
-                        param.grad = old_grad
+                for param, old_grad in zip(self.parameters, old_grads):
+                    param.grad = old_grad
+
+    def notify_used_averaged_gradients(self):
+        """Notify averager that the results of a previous averaging round are accounted for"""
+        self._new_averaged_grads = False

+ 28 - 21
hivemind/optim/experimental/optimizer.py

@@ -123,6 +123,7 @@ class Optimizer(torch.optim.Optimizer):
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_round: Optional[StepControl] = None
+        self.previous_round: Optional[StepControl] = None
 
         self.state_averager = self._make_state_averager(
             optimizer=optimizer, params=params, scheduler=scheduler, **averager_opts or {}
@@ -263,6 +264,10 @@ class Optimizer(torch.optim.Optimizer):
 
         if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
             if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
+                if self.delay_grad_averaging:
+                    # wait for previous averaging to finish before starting a new one
+                    self.state_averager.step(wait_for_delayed_update=True)
+
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                 eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
@@ -274,7 +279,11 @@ class Optimizer(torch.optim.Optimizer):
         if not self.tracker.ready_to_update_epoch:
             return loss
 
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+
         with self.tracker.pause_updates():
+            # note: we do not need to replace grads because we explicitly load grads into the optimizer
+
             logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
 
             if grad_scaler is not None:
@@ -287,7 +296,6 @@ class Optimizer(torch.optim.Optimizer):
 
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
             began_averaging_gradients = False
-
             if swarm_not_empty:
                 try:
                     self.scheduled_round = self.grad_averager.step(
@@ -306,27 +314,24 @@ class Optimizer(torch.optim.Optimizer):
             if not self.delay_grad_averaging:
                 self._average_gradients_and_load_into_optimizer(self.scheduled_round)
 
-            assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
-            with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
-                # note: we do not need to replace because the offloaded optimizer is already using averaged grads
-                next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
-
-                self.state_averager.step(
-                    increment_epoch=True,
-                    optimizer_step=not self.auxiliary,
-                    delay_optimizer_step=self.delay_optimizer_step,
-                    averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
-                    delay_averaging=not self.auxiliary,
-                    grad_scaler=grad_scaler,
-                    wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
-                    if self.delay_grad_averaging
-                    else None,
-                    averaging_opts=dict(
-                        scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
-                    )
-                    if swarm_not_empty
-                    else None,
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                optimizer_step=not self.auxiliary,
+                delay_optimizer_step=self.delay_optimizer_step,
+                averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
+                delay_averaging=not self.auxiliary,
+                grad_scaler=grad_scaler,
+                wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
+                if self.delay_grad_averaging
+                else None,
+                averaging_opts=dict(
+                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                 )
+                if swarm_not_empty
+                else None,
+            )
 
             if not self.auxiliary:
                 self.grad_averager.reset_accumulated_grads_()
@@ -355,6 +360,8 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Proceeding with local gradients")
             self.grad_averager.load_accumulators_into_averager_()
 
+        self.grad_averager.notify_used_averaged_gradients()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers:

+ 1 - 0
hivemind/optim/experimental/state_averager.py

@@ -322,6 +322,7 @@ class TrainingStateAverager(DecentralizedAverager):
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
         if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
             if not (self.reuse_tensors or self.custom_gradients):
                 # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
                 # should be used for optimizer step (e.g. the gradients that were present during the call to .step or