Kaynağa Gözat

black-isort

justheuristic 3 yıl önce
ebeveyn
işleme
1a6aa77f7d

+ 1 - 1
hivemind/optim/experimental/optimizer.py

@@ -22,7 +22,7 @@ from hivemind.optim.experimental.state_averager import (
     TrainingStateAverager,
 )
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_dht_time, get_logger, DHTExpiration
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 

+ 2 - 3
hivemind/optim/experimental/state_averager.py

@@ -1,10 +1,10 @@
 """ An extension of averager that supports common optimization use cases. """
 import logging
+import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
-import threading
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
@@ -14,7 +14,7 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack, get_dht_time, DHTExpiration, PerformanceEMA
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 
@@ -589,7 +589,6 @@ class TrainingStateAverager(DecentralizedAverager):
                     delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
                     local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
 
-
     def get_current_state(self):
         """
         Get current model/optimizer state and when requested by a newbie peer. executed in the host process.

+ 1 - 1
hivemind/utils/__init__.py

@@ -5,7 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
-from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.timed_storage import *