
use local accumulators

justheuristic 3 years ago
parent
commit
b656f15d2a

+ 2 - 2
hivemind/optim/experimental/grad_averager.py

@@ -170,7 +170,7 @@ class GradientAverager(DecentralizedAverager):
         elif len(kwargs) > 0:
             raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
-        self._load_accumulators_into_averager_()
+        self.load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
         self._new_averaged_grads = True
 
@@ -182,7 +182,7 @@ class GradientAverager(DecentralizedAverager):
         return control.result(timeout) if wait else control
 
     @torch.no_grad()
-    def _load_accumulators_into_averager_(self):
+    def load_accumulators_into_averager_(self):
         """load locally accumulated gradients into the averager for aggregation"""
         if self._new_averaged_grads and self.warn:
             logger.warning(

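Dropping the leading underscore turns `_load_accumulators_into_averager_` into a public method, so the optimizer (second file below) can push the locally accumulated gradients into the averaged-gradient buffers itself whenever no peer average is available. Below is a minimal toy sketch of that accumulate-then-load contract in plain PyTorch; `ToyGradAverager` and its buffers are illustrative, not hivemind's actual implementation.

import torch


class ToyGradAverager:
    """Illustrative stand-in for GradientAverager: keeps per-parameter local
    accumulators plus a separate buffer that normally holds peer-averaged grads."""

    def __init__(self, parameters):
        self.parameters = list(parameters)
        self.grad_accumulators = [torch.zeros_like(p) for p in self.parameters]
        self.averaged_grads = [torch.zeros_like(p) for p in self.parameters]
        self.accumulated_samples = 0

    @torch.no_grad()
    def accumulate_grads_(self, batch_size: int) -> None:
        # Add each parameter's current .grad into the local accumulators.
        for acc, param in zip(self.grad_accumulators, self.parameters):
            if param.grad is not None:
                acc.add_(param.grad)
        self.accumulated_samples += batch_size

    @torch.no_grad()
    def load_accumulators_into_averager_(self) -> None:
        # Public here, mirroring the rename in this commit: copy the normalized
        # local accumulators into the buffer that peer averaging would otherwise fill.
        for avg, acc in zip(self.averaged_grads, self.grad_accumulators):
            avg.copy_(acc / max(self.accumulated_samples, 1))
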
+ 3 - 1
hivemind/optim/experimental/optimizer.py

@@ -259,12 +259,14 @@ class Optimizer(torch.optim.Optimizer):
                     )
                     logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {repr(e)}")
+                    logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}, using local grads")
+                    self.grad_averager.load_accumulators_into_averager_()
 
             else:
                 if self.scheduled_round is not None:
                     self.scheduled_round.cancel()
                 logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+                self.grad_averager.load_accumulators_into_averager_()
 
             assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
             with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
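
Both new call sites cover the case where no peer-averaged gradients exist: whether averaging raised an exception or there were simply no other peers, the optimizer now fills the averaged-gradient buffers from the local accumulators before `use_averaged_gradients(...)` runs, so the subsequent step applies the node's own gradients rather than stale or empty ones. A hedged sketch of that control flow, reusing the `ToyGradAverager` sketch above; `attempt_collective_averaging` is a made-up placeholder for the real peer all-reduce, passed in as a callable:

def average_gradients_or_fall_back(grad_averager, attempt_collective_averaging, has_peers: bool):
    """Mirror of the new Optimizer logic: try collective averaging, and on any
    failure (or when running alone) fall back to locally accumulated gradients."""
    if has_peers:
        try:
            group_info = attempt_collective_averaging()  # e.g. a grad_averager.step(...) call
            print(f"Averaged gradients with {len(group_info)} peers")
        except BaseException as exc:
            print(f"Averaging gradients failed with {exc!r}, using local grads")
            grad_averager.load_accumulators_into_averager_()
    else:
        print("Skipped averaging: there are no other peers")
        grad_averager.load_accumulators_into_averager_()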