black-isort

justheuristic 3 years ago
parent
commit
818a3a59f6

+ 2 - 2
benchmarks/benchmark_optimizer.py

@@ -1,6 +1,7 @@
 import multiprocessing as mp
 import random
 import time
+from contextlib import nullcontext
 from dataclasses import dataclass
 from functools import partial
 from typing import Callable
@@ -15,7 +16,6 @@ from torch.utils.data import Dataset
 import hivemind
 from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
-from contextlib import nullcontext
 
 
 @dataclass(frozen=True)
@@ -151,4 +151,4 @@ def benchmark_optimizer(args: TrainingArguments):
             peer.join()
     finally:
         for peer in peers[1:]:
-            peer.kill()
+            peer.kill()

+ 1 - 1
hivemind/optim/experimental/state_averager.py

@@ -545,7 +545,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
             if self.offload_optimizer:
                 optimized_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-                loaded_parameters = loaded_parameters_and_extras[:len(optimized_parameters)]
+                loaded_parameters = loaded_parameters_and_extras[: len(optimized_parameters)]
                 for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
                     local_param.copy_(loaded_param, non_blocking=True)
 

+ 5 - 4
hivemind/optim/grad_scaler.py

@@ -1,11 +1,11 @@
 import contextlib
+import threading
 from copy import deepcopy
 from typing import Dict, Optional
-import threading
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state, OptState
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
@@ -56,8 +56,9 @@ class GradScaler(TorchGradScaler):
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step.")
                 assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-                assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
-                    "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+                assert (
+                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
                 if self.are_grads_finite(optimizer, use_cached=True):
                     super().step(optimizer, *args, **kwargs)
                 else:
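
A minimal sketch (not part of this commit) of how a formatting pass like this could be reproduced locally. The target paths and the lack of extra flags are assumptions; the repository may pin specific black/isort options in its own configuration.

    # Sketch: re-run the formatters over the touched packages.
    # Paths and default options are assumptions, not taken from the repo config.
    import subprocess

    for cmd in (
        ["isort", "benchmarks", "hivemind"],  # group/order imports (stdlib, third-party, local)
        ["black", "benchmarks", "hivemind"],  # normalize layout: slice spacing, wrapped asserts, trailing newline
    ):
        subprocess.run(cmd, check=True)

Running isort before black keeps the two tools from disagreeing about import layout, which matches the combined "black-isort" nature of this commit.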