|
@@ -80,7 +80,7 @@ class GradScaler(TorchGradScaler):
|
|
|
|
|
|
def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
|
|
|
assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
- return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
|
|
|
+ return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
|
|
|
|
|
|
|
|
|
class HivemindGradScaler(GradScaler):
|