justheuristic 3 years ago
Parent
Commit
32dd570f55

+ 1 - 1
hivemind/__init__.py

@@ -18,7 +18,7 @@ from hivemind.optim import (
     DecentralizedSGD,
     TrainingAverager,
     GradScaler,
-    Optimizer
+    Optimizer,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
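Aside from the trailing comma (black style), this hunk shows that GradScaler and Optimizer are re-exported at the package root, so both are reachable as attributes of hivemind itself. A quick check, assuming a hivemind build that includes this commit:

    import hivemind
    from hivemind.optim import GradScaler, Optimizer

    # The re-export in hivemind/__init__.py makes the top-level names
    # point at the same classes as hivemind.optim.
    assert hivemind.GradScaler is GradScaler
    assert hivemind.Optimizer is Optimizer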

+ 8 - 6
hivemind/optim/experimental/optimizer.py

@@ -187,11 +187,13 @@ class Optimizer(torch.optim.Optimizer):
         """If true, peer will discard local progress and attempt to download state from peers."""
         return self.local_epoch < self.tracker.global_epoch - self.epoch_tolerance
 
-    def step(self,
-             closure: Optional[Callable[[], torch.Tensor]] = None,
-             batch_size: Optional[int] = None,
-             grad_scaler: Optional[HivemindGradScaler] = None,
-             **kwargs):
+    def step(
+        self,
+        closure: Optional[Callable[[], torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        grad_scaler: Optional[HivemindGradScaler] = None,
+        **kwargs,
+    ):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
@@ -358,7 +360,7 @@ class Optimizer(torch.optim.Optimizer):
         param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
         for param_group in param_groups:
             num_params = len(param_group["params"])
-            main_params_for_group = self.state_averager.main_parameters[next_index: next_index + num_params]
+            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
             param_group["params"] = main_params_for_group
             next_index += num_params
         assert next_index == len(self.state_averager.main_parameters)
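The reformatted signature keeps the same call contract: closure, batch_size, and grad_scaler remain keyword arguments. A minimal usage sketch, assuming opt is an already configured hivemind.Optimizer wrapping model (the names model, batch, and loss_fn are placeholders, not part of this commit):

    def training_step(opt, model, batch, loss_fn):
        # Ordinary forward/backward pass; gradients accumulate locally.
        inputs, targets = batch
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        # Per the docstring above: report batch_size additional samples and
        # let the optimizer decide whether a global update happens now.
        opt.step(batch_size=len(inputs))
        return loss.detach()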

+ 7 - 3
hivemind/optim/experimental/state_averager.py

@@ -11,7 +11,6 @@ import torch
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import CompressionInfo, TensorRole
-from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -284,7 +283,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        grad_scaler: Optional[GradScaler] = None,
+        grad_scaler: Optional[hivemind.GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -380,7 +379,12 @@ class TrainingStateAverager(DecentralizedAverager):
         return output
 
     def _do(
-        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+        self,
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        grad_scaler: Optional[hivemind.GradScaler],
+        **kwargs,
     ):
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
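In state_averager.py the direct `from hivemind.optim.grad_scaler import GradScaler` import is dropped and the annotations now go through the top-level hivemind namespace, which is already imported a few lines above. A likely motivation, stated here as an assumption rather than a fact from the commit, is avoiding an import cycle between the grad-scaler module and the experimental optimizer stack. A minimal sketch of the pattern with hypothetical names (pkg, pkg.heavy, pkg.user are illustrative, not hivemind modules):

    # pkg/__init__.py  (hypothetical package)
    from pkg.heavy import HeavyClass   # bind the name on the package first
    from pkg import user               # ...then load modules that refer to it

    # pkg/heavy.py
    class HeavyClass:
        pass

    # pkg/user.py
    from typing import Optional

    import pkg  # refer to HeavyClass via the package, not via pkg.heavy


    def run(scaler: Optional[pkg.HeavyClass] = None) -> None:
        # The annotation is evaluated when this def executes, by which point
        # pkg/__init__.py has already bound HeavyClass; pkg.user never imports
        # pkg.heavy directly, so a heavy -> user -> heavy cycle cannot form.
        ...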