@@ -1,27 +1,14 @@
-import os
-
 import torch
-import transformers
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
 from src import RemoteSequential
-from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM
+from src.client.remote_model import DistributedBloomConfig
+from test_utils import *
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-
 def test_remote_sequential():
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
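
Note: the hunk star-imports test_utils, so the INITIAL_PEERS / MODEL_NAME environment parsing removed above presumably moves into that shared module. A minimal sketch of what test_utils.py would need to export for the unchanged test body to keep working, reconstructed from the deleted lines (assumed layout, not the actual file):

# test_utils.py -- assumed contents, mirroring the lines removed from this test module
import os

# peer multiaddrs are passed as a space-separated list in INITIAL_PEERS
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()

# name of the converted model checkpoint under test
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME environment variable with the model to be tested")

With this layout, every test file can `from test_utils import *` instead of duplicating the environment-variable checks at module level.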