
[previous commit is verified as stable] implement delayed averaging

justheuristic, 3 years ago (ec05f8b358)
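This commit adds a `delay_grad_averaging` option alongside `delay_optimizer_step`: when enabled, the blocking gradient all-reduce is moved off the training thread into the background executor, and `delay_optimizer_step` defaults to the same value when left as `None`. Below is a minimal usage sketch under those assumptions; only the keyword arguments that appear in this diff (`dht`, `prefix`, `target_batch_size`, `batch_size_per_step`, `delay_grad_averaging`, `delay_optimizer_step`) are confirmed by it, while the DHT setup, the `params`/`optimizer` wiring, and the concrete values are illustrative.

```python
import torch
import hivemind
from hivemind.optim.experimental.optimizer import Optimizer

dht = hivemind.DHT(start=True)   # a running DHT peer (setup assumed, not part of this diff)
model = torch.nn.Linear(16, 1)   # illustrative model

opt = Optimizer(
    dht=dht,
    prefix="my_run",                     # experiment-wide key prefix
    params=model.parameters(),           # argument name assumed, not shown in this diff
    optimizer=lambda params: torch.optim.Adam(params),  # argument name assumed as well
    target_batch_size=4096,
    batch_size_per_step=32,
    delay_grad_averaging=True,           # new flag: run gradient averaging in the background
    # delay_optimizer_step is left as None and therefore defaults to delay_grad_averaging
)
```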

+ 42 - 13
hivemind/optim/experimental/optimizer.py

@@ -2,6 +2,7 @@ from __future__ import annotations

 import logging
 import os
+from functools import partial
 from typing import Callable, Optional, Union

 import torch
@@ -98,7 +99,8 @@ class Optimizer(torch.optim.Optimizer):
         load_state_timeout: float = 600.0,
         average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
-        delay_optimizer_step: bool = False,
+        delay_grad_averaging: bool = False,
+        delay_optimizer_step: Optional[bool] = None,
         client_mode: bool = None,
         auxiliary: bool = False,
         averager_opts: Optional[dict] = None,
@@ -107,13 +109,15 @@ class Optimizer(torch.optim.Optimizer):
         verbose: bool = False,
     ):
         client_mode = client_mode if client_mode is None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"

         self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
-        self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
-        self.average_state_every = average_state_every
+        self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
 
@@ -282,21 +286,25 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_round = None

             swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            began_averaging = False
+
             if swarm_not_empty:
                 try:
-                    group_info = self.grad_averager.step(
-                        control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
+                    self.scheduled_round = self.grad_averager.step(
+                        control=self.scheduled_round, reset_accumulators=True, wait=False
                     )
-                    logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                    assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
+                    began_averaging = True
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}, using local grads")
-                    self.grad_averager.load_accumulators_into_averager_()
+                    logger.exception(e)

-            else:
-                if self.scheduled_round is not None and not self.scheduled_round.done():
-                    self.scheduled_round.cancel()
-                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
-                self.grad_averager.load_accumulators_into_averager_()
+            if not began_averaging and self.scheduled_round is not None and not self.scheduled_round.done():
+                logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
+                self.scheduled_round.cancel()
+                self.scheduled_round = None
+
+            if not self.delay_grad_averaging:
+                self._average_gradients_and_load_into_optimizer(self.scheduled_round)

             assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
             with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
@@ -310,6 +318,8 @@ class Optimizer(torch.optim.Optimizer):
                     averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
                     delay_averaging=not self.auxiliary,
                     grad_scaler=grad_scaler,
+                    wait_for_trigger=partial(
+                        self._average_gradients_and_load_into_optimizer, self.scheduled_round) if self.delay_grad_averaging else None,
                     averaging_opts=dict(
                         scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                     )
@@ -325,6 +335,25 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Optimizer step done! Transitioning to epoch {self.local_epoch}.")
         return loss
 
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert maybe_step_control is None or maybe_step_control.triggered
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            logger.log(self.status_loglevel, f"Proceeding with local gradients")
+            self.grad_averager.load_accumulators_into_averager_()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers:
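Taken together, the optimizer.py changes replace the blocking `grad_averager.step(..., timeout=self.averaging_timeout)` call with a non-blocking `step(..., wait=False)` that returns a control object, and move the actual wait into the new `_average_gradients_and_load_into_optimizer`, which either runs inline (when `delay_grad_averaging` is off) or is handed to the state averager as `wait_for_trigger` via `functools.partial`. Here is a standalone toy sketch of that control flow, using `concurrent.futures` stand-ins rather than hivemind classes:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
from typing import Optional


def finish_averaging(maybe_round: Optional[Future]) -> None:
    """Stand-in for _average_gradients_and_load_into_optimizer: await the scheduled
    round if there is one, otherwise (or on failure) fall back to local gradients."""
    try:
        if maybe_round is not None:
            peers = maybe_round.result(timeout=60.0)  # blocks only the background thread
            print(f"averaged gradients with {len(peers)} peers")
            return
        print("skipped averaging: there are no other peers")
    except BaseException as e:
        print(f"averaging gradients failed with {e!r}")
    print("proceeding with local gradients")


def background_step(wait_for_trigger) -> None:
    """Stand-in for TrainingStateAverager.step running in the background executor."""
    wait_for_trigger()  # gradient averaging finishes here, off the training thread
    print("running optimizer step")


scheduled_round: Optional[Future] = None  # stand-in for grad_averager.step(..., wait=False)
with ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(background_step, partial(finish_averaging, scheduled_round)).result()
```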

+ 5 - 0
hivemind/optim/experimental/state_averager.py

@@ -409,9 +409,12 @@ class TrainingStateAverager(DecentralizedAverager):
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         """
+        began_running = False
         try:
             if wait_for_trigger is not None:
                 wait_for_trigger()
+            began_running = True
+
             if optimizer_step:
                 with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
                     logger.log(self.status_loglevel, f"Running optimizer step")
@@ -455,6 +458,8 @@ class TrainingStateAverager(DecentralizedAverager):
                         self._update_scheduler()

         except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception.")
             logger.exception(e)
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()
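On the state_averager.py side, the only addition is a guard: if the `wait_for_trigger` callback raises (for instance, because awaiting the gradient averaging round failed), the background step is aborted before any optimizer work has started, and that case is logged separately from failures of the later stages. A minimal illustration of the pattern (toy code, not hivemind's implementation):

```python
import logging

logger = logging.getLogger(__name__)


def do_step(wait_for_trigger=None, optimizer_step=True):
    """Toy version of the guarded background step: distinguish trigger failures
    from failures that happen after real work has begun."""
    began_running = False
    try:
        if wait_for_trigger is not None:
            wait_for_trigger()  # may raise before any optimizer work is done
        began_running = True

        if optimizer_step:
            logger.info("Running optimizer step")
            # ... optimizer step, scheduler step, averaging round would go here ...
    except Exception as e:
        if not began_running:
            logger.error("Aborted step because wait_for_trigger raised an exception.")
        logger.exception(e)


def failing_trigger():
    raise RuntimeError("averaging failed")


# A failing trigger is reported as an aborted step, not as an optimizer failure.
do_step(wait_for_trigger=failing_trigger)
```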