Browse Source

black-isort

justheuristic 3 years ago
parent
commit
31163a95f8

+ 1 - 1
tests/test_block_exact_match.py

@@ -3,12 +3,12 @@ import random
 import hivemind
 import torch
 import transformers
+from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
-from test_utils import *
 
 
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):

+ 1 - 1
tests/test_chained_calls.py

@@ -8,11 +8,11 @@ import hivemind
 import torch
 import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
+from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
 from src.dht_utils import get_remote_module
-from test_utils import *
 
 
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):

+ 0 - 1
tests/test_full_model.py

@@ -9,7 +9,6 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

+ 1 - 1
tests/test_remote_sequential.py

@@ -1,9 +1,9 @@
 import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
+from test_utils import *
 
 from src import RemoteSequential
 from src.client.remote_model import DistributedBloomConfig
-from test_utils import *
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)