|
@@ -208,6 +208,7 @@ class GradientAverager(DecentralizedAverager):
|
|
|
@contextlib.contextmanager
|
|
|
@torch.no_grad()
|
|
|
def use_averaged_gradients(self):
|
|
|
+ """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
|
|
|
self._new_averaged_grads = False
|
|
|
with self.get_tensors() as averaged_grads:
|
|
|
assert len(averaged_grads) == len(self.parameters)
|