justheuristic 3 years ago
parent
commit
b7e1c94139
3 changed files with 5 additions and 5 deletions
  1. hivemind/__init__.py  +1 -1
  2. hivemind/optim/experimental/optimizer.py  +1 -1
  3. hivemind/optim/experimental/state_averager.py  +3 -3

+ 1 - 1
hivemind/__init__.py

@@ -16,9 +16,9 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
-    TrainingAverager,
     GradScaler,
     Optimizer,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *

+ 1 - 1
hivemind/optim/experimental/optimizer.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import Optional, Union, Callable
+from typing import Callable, Optional, Union
 
 import torch
 

+ 3 - 3
hivemind/optim/experimental/state_averager.py

@@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
 
 import torch
 
-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -283,7 +283,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        grad_scaler: Optional[hivemind.GradScaler] = None,
+        grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -383,7 +383,7 @@ class TrainingStateAverager(DecentralizedAverager):
         optimizer_step: bool,
         zero_grad: bool,
         averaging_round: bool,
-        grad_scaler: Optional[hivemind.GradScaler],
+        grad_scaler: Optional[GradScaler],
         **kwargs,
     ):
         """