justheuristic 3 年 前
コミット
5d5d6883ba
1 ファイル変更2 行追加1 行削除
  1. 2 1
      hivemind/optim/grad_scaler.py

+ 2 - 1
hivemind/optim/grad_scaler.py

@@ -1,4 +1,5 @@
 import contextlib
+from copy import deepcopy
 from typing import Dict, Optional
 
 import torch
@@ -37,7 +38,7 @@ class GradScaler(TorchGradScaler):
         assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         if self._is_running_global_step:
             super().unscale_(optimizer)
-            self._per_optimizer_states[id(optimizer.opt)] = self._per_optimizer_states[id(optimizer)]
+            self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
             return True
         else:
             self._check_inf_per_device(optimizer)