
Fix offloaded optimizer with single peer (#450)

Previously, hivemind.Optimizer was hard-wired to use the averaged gradients -- as in "averaged with peers".
If you are the only peer, gradients are never averaged, so the optimizer runs with zero gradients all the time.
This PR fixes that by loading local gradients into the optimizer if there are no other peers (or if averaging fails for any other reason); a short sketch of the single-peer setup follows below.

Co-authored-by: Xiangpeng Wan <elricwan@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
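
For context, here is a minimal sketch of the single-peer scenario this commit addresses. It assumes the hivemind.Optimizer constructor arguments documented around this release (dht, run_id, batch_size_per_step, target_batch_size, optimizer, params, offload_optimizer, verbose); the model, run_id, and hyperparameters are hypothetical placeholders, not part of the change itself.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import hivemind

    model = nn.Linear(16, 2)            # toy model; the architecture is irrelevant to the bug
    dht = hivemind.DHT(start=True)      # this process is the only peer in the run

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="single_peer_demo",      # hypothetical run name
        batch_size_per_step=32,
        target_batch_size=32,           # every local step reaches the target batch size
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        params=model.parameters(),
        offload_optimizer=True,         # the configuration affected by this fix
        verbose=True,
    )

    x, y = torch.randn(32, 16), torch.randint(0, 2, (32,))
    F.cross_entropy(model(x), y).backward()
    opt.step()       # with no other peers, the optimizer now falls back to local gradients
    opt.zero_grad()

Before this change, a script like the one above would step on zero gradients every epoch, because the averaged-gradient buffers used by the offloaded optimizer were never filled; with the fallback, the local gradient accumulators are loaded into the averager and from there into the optimizer.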
justheuristic, 3 years ago
commit a974b551c6
1 changed file with 13 additions and 4 deletions

hivemind/optim/optimizer.py (+13, -4)

@@ -442,7 +442,8 @@ class Optimizer(torch.optim.Optimizer):
 
                 began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                 if not began_averaging_gradients:
-                    pass  # failed to start gradient averaging due to an internal error
+                    # failed to start gradient averaging due to an internal error
+                    self.grad_averager.load_accumulators_into_averager_()
                 elif self.delay_grad_averaging:
                     # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                     wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
@@ -529,6 +530,7 @@ class Optimizer(torch.optim.Optimizer):
                 self._tag_along_with_zero_weight(self.scheduled_grads)
             else:
                 logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
+                self._load_local_gradients_into_optimizer()
                 self.scheduled_grads.cancel()
             self.scheduled_grads = None
         return began_averaging_gradients
@@ -597,9 +599,7 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
 
         if not averaged_gradients:
-            logger.log(self.status_loglevel, f"Proceeding with local gradients")
-            self.grad_averager.load_accumulators_into_averager_()
-            self._load_averaged_gradients_into_optimizer_()
+            self._load_local_gradients_into_optimizer()
 
     def _load_averaged_gradients_into_optimizer_(self):
         """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
@@ -618,6 +618,15 @@ class Optimizer(torch.optim.Optimizer):
 
         self.grad_averager.notify_used_averaged_gradients()
 
+    def _load_local_gradients_into_optimizer(self):
+        """Fallback to using local gradients in the optimizer (instead of averaged gradients)"""
+        logger.log(self.status_loglevel, f"Proceeding with local gradients")
+        self.grad_averager.load_accumulators_into_averager_()
+        # note: we load gradients into grad_averager even though there is only one peer, for two reasons:
+        # - if offload_optimizer, then we must load gradients onto the CPU gradient buffers used by the optimizer
+        # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
+        self._load_averaged_gradients_into_optimizer_()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers: