|
@@ -1,4 +1,5 @@
|
|
|
import contextlib
|
|
|
+from copy import deepcopy
|
|
|
from typing import Dict, Optional
|
|
|
|
|
|
import torch
|
|
@@ -37,7 +38,7 @@ class GradScaler(TorchGradScaler):
|
|
|
assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
if self._is_running_global_step:
|
|
|
super().unscale_(optimizer)
|
|
|
- self._per_optimizer_states[id(optimizer.opt)] = self._per_optimizer_states[id(optimizer)]
|
|
|
+ self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
|
|
|
return True
|
|
|
else:
|
|
|
self._check_inf_per_device(optimizer)
|