@@ -1,8 +1,8 @@
 import ctypes
 import multiprocessing as mp
+import random
 import time
 from functools import partial
-import random
 
 import numpy as np
 import pytest
@@ -13,9 +13,9 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
-from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
 
@@ -287,9 +287,16 @@ def test_progress_tracker():
     assert tracker.performance_ema.samples_per_second < 1e-9
 
 
-def test_optimizer(num_peers: int = 2, num_clients: int = 1, target_batch_size: int = 64, total_epochs: int = 3,
-                   reuse_grad_buffers: bool = True, delay_grad_averaging: bool = True,
-                   delay_optimizer_step: bool = True, average_state_every: int = 1):
+def test_optimizer(
+    num_peers: int = 2,
+    num_clients: int = 1,
+    target_batch_size: int = 64,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
     dht = hivemind.DHT(start=True)
 
     features = torch.randn(100, 5)
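The hunk above only reformats the `test_optimizer` entry point into one-keyword-per-line style with a trailing comma; names and defaults are unchanged. For orientation, these keywords are the knobs that the (elided) test body forwards to the optimizer under test. A minimal sketch of that wiring, assuming a toy model plus `prefix`, `params`, `optimizer`, and `batch_size_per_step` arguments that do not appear anywhere in this diff:

    # Illustrative sketch only -- not the test body from this diff.
    # `prefix`, `params`, `optimizer`, and `batch_size_per_step` are assumptions.
    model = torch.nn.Linear(5, 1)
    opt = hivemind.Optimizer(
        dht=dht,                                # from hivemind.DHT(start=True) above
        prefix="test_optimizer_run",            # assumed experiment identifier
        params=model.parameters(),
        optimizer=partial(torch.optim.SGD, lr=0.1),
        target_batch_size=target_batch_size,    # forwarded from the test signature
        reuse_grad_buffers=reuse_grad_buffers,
        delay_grad_averaging=delay_grad_averaging,
        delay_optimizer_step=delay_optimizer_step,
        average_state_every=average_state_every,
    )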
@@ -367,4 +374,3 @@ def test_optimizer(num_peers: int = 2, num_clients: int = 1, target_batch_size:
 
     assert isinstance(optimizer, Optimizer)
     assert optimizer.local_epoch == total_epochs
-
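The final hunk only drops a stray blank line after the last assertion. Since every option of `test_optimizer` now has a keyword default, the same entry point can also be driven with non-default combinations; a hypothetical pytest parametrization (not part of this diff) might look like:

    # Hypothetical wiring, not present in the diff: reuse the test entry point
    # to exercise both the delayed and the eager execution paths.
    @pytest.mark.parametrize("delay", [True, False])
    def test_optimizer_delay_modes(delay: bool):
        test_optimizer(
            reuse_grad_buffers=delay,
            delay_grad_averaging=delay,
            delay_optimizer_step=delay,
        )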