Sfoglia il codice sorgente

PR attribution (from aux peers)

Michael Diskin 3 anni fa
parent
commit
da440f17aa
1 ha cambiato i file con 1 aggiunte e 0 eliminazioni
  1. 1 0
      hivemind/optim/experimental/grad_averager.py

+ 1 - 0
hivemind/optim/experimental/grad_averager.py

@@ -208,6 +208,7 @@ class GradientAverager(DecentralizedAverager):
     @contextlib.contextmanager
     @torch.no_grad()
     def use_averaged_gradients(self):
+        """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
             assert len(averaged_grads) == len(self.parameters)