justheuristic 3 years ago
parent commit c886472d73

+ 2 - 2
hivemind/optim/collaborative.py

@@ -245,7 +245,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
@@ -310,7 +310,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self.opt)
+                    assert grad_scaler.step(self)
             else:
                 self.opt.step()
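
Both call sites inside CollaborativeOptimizer.step now hand the scaler the wrapper itself (self) rather than the inner torch optimizer (self.opt); dereferencing .opt is left to GradScaler (see the grad_scaler.py hunks below). A minimal sketch of a mixed-precision loop that reaches this code path, assuming a single local peer; the toy model, DHT setup, constructor arguments, and the step(grad_scaler=...) keyword are illustrative assumptions, not part of this commit:

    import torch
    import hivemind
    from hivemind.optim.grad_scaler import GradScaler

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 2).cuda()
    opt = hivemind.CollaborativeOptimizer(
        opt=torch.optim.SGD(model.parameters(), lr=0.1),
        dht=dht,
        prefix="demo_run",           # hypothetical experiment prefix
        target_batch_size=256,
        batch_size_per_step=32,
        start=True,
    )
    scaler = GradScaler()

    for _ in range(4):
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(32, 16, device="cuda")).pow(2).mean()
        scaler.scale(loss).backward()
        scaler.unscale_(opt)              # pass the wrapper to the scaler, never opt.opt
        opt.step(grad_scaler=scaler)      # assumed keyword; the hunks above read it as grad_scaler
        scaler.update()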
 

+ 6 - 4
hivemind/optim/grad_scaler.py

@@ -36,17 +36,18 @@ class GradScaler(TorchGradScaler):
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
         assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         if self._is_running_global_step:
-            super().unscale_(optimizer.opt)
+            super().unscale_(optimizer)
             return True
         else:
-            self._check_inf_per_device(optimizer.opt)
+            self._check_inf_per_device(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             if self.are_grads_finite(optimizer):
-                super().step(optimizer, *args, **kwargs)
+                assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+                super().step(optimizer.opt, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
             return True
@@ -78,7 +79,8 @@ class GradScaler(TorchGradScaler):
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
     def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+        assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
 
 
 class HivemindGradScaler(GradScaler):
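
On the grad_scaler.py side, the scaler now asserts that it received a hivemind wrapper and only dereferences .opt at the points where torch's machinery needs the real optimizer (stepping and the found_inf check). For reference, the one-line finiteness test used above works by summing torch's per-device found_inf flags: the gradients are finite iff every flag is zero. A standalone sketch of the same check against a plain torch optimizer, assuming a backward pass has already been scaled (note that _check_inf_per_device is a private torch API):

    import torch

    def grads_are_finite(scaler: torch.cuda.amp.GradScaler, inner_opt: torch.optim.Optimizer) -> bool:
        # torch keeps one found_inf tensor per device that holds gradients;
        # a nonzero entry means an inf/nan turned up during unscaling
        flags = scaler._check_inf_per_device(inner_opt)
        return not sum(flag.item() for flag in flags.values())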