Browse Source

black-isort

justheuristic 3 years ago
parent
commit
561a0df751
1 changed files with 12 additions and 6 deletions
  1. 12 6
      tests/test_optimizer.py

+ 12 - 6
tests/test_optimizer.py

@@ -1,8 +1,8 @@
 import ctypes
 import multiprocessing as mp
+import random
 import time
 from functools import partial
-import random
 
 import numpy as np
 import pytest
@@ -13,9 +13,9 @@ import torch.nn.functional as F
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
-from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -287,9 +287,16 @@ def test_progress_tracker():
     assert tracker.performance_ema.samples_per_second < 1e-9
 
 
-def test_optimizer(num_peers: int = 2, num_clients: int = 1, target_batch_size: int = 64, total_epochs: int = 3,
-                   reuse_grad_buffers: bool = True, delay_grad_averaging: bool = True,
-                   delay_optimizer_step: bool = True, average_state_every: int = 1):
+def test_optimizer(
+    num_peers: int = 2,
+    num_clients: int = 1,
+    target_batch_size: int = 64,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
     dht = hivemind.DHT(start=True)
 
     features = torch.randn(100, 5)
@@ -367,4 +374,3 @@ def test_optimizer(num_peers: int = 2, num_clients: int = 1, target_batch_size:
 
     assert isinstance(optimizer, Optimizer)
     assert optimizer.local_epoch == total_epochs
-