Selaa lähdekoodia

actually unscale

justheuristic 3 vuotta sitten
vanhempi
commit
dc76d68106
1 muutettua tiedostoa jossa 8 lisäystä ja 5 poistoa
  1. 8 5
      hivemind/optim/grad_scaler.py

+ 8 - 5
hivemind/optim/grad_scaler.py

@@ -3,7 +3,7 @@ from typing import Dict, Optional
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
+from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state, OptState
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
@@ -37,6 +37,7 @@ class GradScaler(TorchGradScaler):
         assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         if self._is_running_global_step:
             super().unscale_(optimizer)
+            self._per_optimizer_states[id(optimizer.opt)] = self._per_optimizer_states[id(optimizer)]
             return True
         else:
             self._check_inf_per_device(optimizer)
@@ -45,12 +46,14 @@ class GradScaler(TorchGradScaler):
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
-                optimizer = optimizer.opt
-            super().step(optimizer, *args, **kwargs)
+            if self.are_grads_finite(optimizer):
+                assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED
+                super().step(optimizer.opt, *args, **kwargs)
+            else:
+                logger.warning("Skipping global step due to gradient over/underflow")
             return True
         else:
-            super().step(optimizer, *args, **kwargs)
+            super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False