@@ -1,10 +1,10 @@
 """ An extension of averager that supports common optimization use cases. """
 import logging
+import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
-import threading
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union

 import torch
@@ -14,7 +14,7 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack, get_dht_time, DHTExpiration, PerformanceEMA
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack

 logger = get_logger(__name__)

@@ -589,7 +589,6 @@ class TrainingStateAverager(DecentralizedAverager):
                     delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
                     local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))

-
     def get_current_state(self):
         """
         Get the current model/optimizer state when requested by a newbie peer. Executed in the host process.
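
Note on the last hunk: the two context lines above are hivemind's delta-rule averaging step, where the averaged result is applied to the live parameter as an in-place delta rather than a wholesale copy. Below is a minimal, self-contained sketch of that pattern; the tensor values are made up for illustration and are not part of this diff.

import torch

# Hypothetical values, chosen only to illustrate the update pattern above.
local_tensor = torch.ones(3)                 # live training parameter
old_tensor = torch.full((3,), 0.5)           # snapshot taken before averaging
new_tensor = torch.tensor([0.8, 0.9, 1.0])   # result produced by the averager

# torch.sub(..., out=old_tensor) reuses the snapshot as the output buffer,
# avoiding an extra allocation; old_tensor now holds new - old.
delta = torch.sub(new_tensor, old_tensor, out=old_tensor)

# Apply the change in place, casting to the local tensor's device/dtype
# (the live parameter may sit on a GPU or use a different precision).
local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
print(local_tensor)  # tensor([1.3000, 1.4000, 1.5000])

Applying the averaged state as a delta (rather than overwriting) lets local optimizer progress made during the averaging round survive the update.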