
[previous commit is verified as stable] implement delayed averaging

justheuristic 3 years ago
parent
commit
ec05f8b358
2 changed files with 47 additions and 13 deletions
  1. hivemind/optim/experimental/optimizer.py  +42 -13
  2. hivemind/optim/experimental/state_averager.py  +5 -0
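
What "delayed averaging" means here: with delay_grad_averaging=True, the gradient all-reduce is launched in the background instead of blocking step(), and it is completed only right before the (also delayed) optimizer step. A hedged usage sketch of the new flag follows; only delay_grad_averaging and delay_optimizer_step come from this diff, while the remaining constructor arguments (dht, prefix, params, optimizer, batch sizes) are assumptions about the surrounding experimental Optimizer API and may differ at this revision.

    import torch
    import hivemind
    from hivemind.optim.experimental.optimizer import Optimizer  # file changed by this commit

    model = torch.nn.Linear(16, 4)      # assumed toy model, for illustration only
    dht = hivemind.DHT(start=True)      # assumed: a running DHT peer

    opt = Optimizer(
        dht=dht,
        prefix="my_experiment",                                     # assumed matchmaking prefix (self.prefix in the diff below)
        params=model.parameters(),                                  # assumed argument name
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),   # assumed factory form
        target_batch_size=4096,
        batch_size_per_step=32,
        delay_grad_averaging=True,   # new flag: average gradients in the background
        delay_optimizer_step=None,   # None now defaults to delay_grad_averaging (see the first hunk below)
    )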

hivemind/optim/experimental/optimizer.py  +42 -13

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
+from functools import partial
 from typing import Callable, Optional, Union
 
 import torch
@@ -98,7 +99,8 @@ class Optimizer(torch.optim.Optimizer):
         load_state_timeout: float = 600.0,
         average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
-        delay_optimizer_step: bool = False,
+        delay_grad_averaging: bool = False,
+        delay_optimizer_step: Optional[bool] = None,
         client_mode: bool = None,
         auxiliary: bool = False,
         averager_opts: Optional[dict] = None,
@@ -107,13 +109,15 @@ class Optimizer(torch.optim.Optimizer):
         verbose: bool = False,
     ):
         client_mode = client_mode if client_mode is None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
 
         self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
-        self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
-        self.average_state_every = average_state_every
+        self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
 
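The two lines added above fix the relationship between the new flags: leaving delay_optimizer_step at None makes it follow delay_grad_averaging, and requesting delayed gradient averaging without a delayed optimizer step is rejected. A standalone restatement of that logic (hypothetical helper name, same expressions as the diff):

    from typing import Optional

    def resolve_delay_flags(delay_grad_averaging: bool, delay_optimizer_step: Optional[bool]) -> bool:
        """Hypothetical helper mirroring the constructor logic above."""
        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
        return delay_optimizer_step

    assert resolve_delay_flags(True, None) is True     # delayed averaging implies a delayed optimizer step
    assert resolve_delay_flags(False, None) is False   # default: both run immediately
    assert resolve_delay_flags(False, True) is True    # the step may still be delayed on its own
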
@@ -282,21 +286,25 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_round = None
 
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            began_averaging = False
+
             if swarm_not_empty:
                 try:
-                    group_info = self.grad_averager.step(
-                        control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
+                    self.scheduled_round = self.grad_averager.step(
+                        control=self.scheduled_round, reset_accumulators=True, wait=False
                     )
-                    logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                    assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
+                    began_averaging = True
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}, using local grads")
-                    self.grad_averager.load_accumulators_into_averager_()
+                    logger.exception(e)
 
-            else:
-                if self.scheduled_round is not None and not self.scheduled_round.done():
-                    self.scheduled_round.cancel()
-                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
-                self.grad_averager.load_accumulators_into_averager_()
+            if not began_averaging and self.scheduled_round is not None and not self.scheduled_round.done():
+                logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
+                self.scheduled_round.cancel()
+                self.scheduled_round = None
+
+            if not self.delay_grad_averaging:
+                self._average_gradients_and_load_into_optimizer(self.scheduled_round)
 
             assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
             with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
@@ -310,6 +318,8 @@ class Optimizer(torch.optim.Optimizer):
                     averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
                     delay_averaging=not self.auxiliary,
                     grad_scaler=grad_scaler,
+                    wait_for_trigger=partial(
+                        self._average_gradients_and_load_into_optimizer, self.scheduled_round) if self.delay_grad_averaging else None,
                     averaging_opts=dict(
                         scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                     )
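
When delay_grad_averaging is enabled, step() no longer waits for the all-reduce itself: functools.partial binds the scheduled round to _average_gradients_and_load_into_optimizer and passes it as wait_for_trigger, so the state averager's background executor finishes the averaging immediately before applying the optimizer step. A minimal, generic illustration of that trigger pattern (plain stand-ins, not the hivemind classes):

    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    def average_and_load(scheduled_round):            # stands in for _average_gradients_and_load_into_optimizer
        print(f"finishing all-reduce for {scheduled_round!r} before stepping")

    def background_step(optimizer_step, wait_for_trigger=None):
        if wait_for_trigger is not None:
            wait_for_trigger()                        # runs in the background, right before the step
        optimizer_step()

    with ThreadPoolExecutor(max_workers=1) as executor:
        trigger = partial(average_and_load, "scheduled_round")    # bind the control object now, call it later
        executor.submit(background_step, lambda: print("optimizer.step()"), wait_for_trigger=trigger).result()
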
@@ -325,6 +335,25 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Optimizer step done! Transitioning to epoch {self.local_epoch}.")
         return loss
 
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert maybe_step_control is None or maybe_step_control.triggered
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            logger.log(self.status_loglevel, f"Proceeding with local gradients")
+            self.grad_averager.load_accumulators_into_averager_()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers:

hivemind/optim/experimental/state_averager.py  +5 -0

@@ -409,9 +409,12 @@ class TrainingStateAverager(DecentralizedAverager):
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         """
+        began_running = False
         try:
             if wait_for_trigger is not None:
                 wait_for_trigger()
+            began_running = True
+
             if optimizer_step:
                 with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
                     logger.log(self.status_loglevel, f"Running optimizer step")
@@ -455,6 +458,8 @@ class TrainingStateAverager(DecentralizedAverager):
                         self._update_scheduler()
 
         except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception.")
             logger.exception(e)
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()
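
The began_running flag added above lets the error handler distinguish a failure inside wait_for_trigger (for example, the delayed gradient-averaging callback) from a failure in the optimizer or averaging logic itself. A generic restatement of that guard, with no hivemind dependencies:

    import logging

    logger = logging.getLogger(__name__)

    def background_step(optimizer_step, wait_for_trigger=None):
        began_running = False
        try:
            if wait_for_trigger is not None:
                wait_for_trigger()          # e.g. the delayed gradient-averaging callback
            began_running = True

            optimizer_step()
        except Exception as e:
            if not began_running:
                # the exception came from wait_for_trigger, not from the step itself
                logger.error("Aborted background_step because wait_for_trigger raised an exception")
            logger.exception(e)

    def failing_trigger():
        raise RuntimeError("gradient averaging failed")

    background_step(lambda: print("optimizer.step()"), wait_for_trigger=failing_trigger)  # logs the abort path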