Przeglądaj źródła

unify environment variable names in all tests

justheuristic 3 lat temu
rodzic
commit
25a4b23640

+ 1 - 12
tests/test_block_exact_match.py

@@ -1,9 +1,6 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
 import random
 
 import hivemind
-import pytest
 import torch
 import transformers
 
@@ -11,15 +8,7 @@ from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
+from test_utils import *
 
 
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):

+ 1 - 11
tests/test_chained_calls.py

@@ -3,7 +3,6 @@
 # - if you want more stable tests, see test_block_exact_match
 # - if you want to figure out chained inference, ask yozh
 
-import os
 
 import hivemind
 import torch
@@ -13,16 +12,7 @@ from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
 from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
+from test_utils import *
 
 
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):

+ 1 - 14
tests/test_full_model.py

@@ -1,8 +1,7 @@
-import os
-
 import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
+from test_utils import *
 
 from src.client.remote_model import DistributedBloomForCausalLM
 
@@ -10,18 +9,6 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME")
-
 
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)

+ 2 - 15
tests/test_remote_sequential.py

@@ -1,27 +1,14 @@
-import os
-
 import torch
-import transformers
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
 from src import RemoteSequential
-from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM
+from src.client.remote_model import DistributedBloomConfig
+from test_utils import *
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-
 def test_remote_sequential():
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)

+ 13 - 0
tests/test_utils.py

@@ -0,0 +1,13 @@
+import os
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as a name of the model to be tested")
+
+REF_NAME = os.environ.get("REF_NAME")