|
@@ -5,7 +5,7 @@ import transformers
|
|
|
from hivemind import DHT, get_logger, use_hivemind_log_handler
|
|
|
|
|
|
from src import RemoteSequential
|
|
|
-from src.client.remote_model import DistributedBloomForCausalLM, DistributedBloomConfig
|
|
|
+from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -36,8 +36,8 @@ def test_remote_sequential():
|
|
|
full_grad = test_inputs.grad.clone()
|
|
|
test_inputs.grad.data.zero_()
|
|
|
|
|
|
- first_half = sequential[:config.n_layer // 2]
|
|
|
- second_half = sequential[config.n_layer // 2:]
|
|
|
+ first_half = sequential[: config.n_layer // 2]
|
|
|
+ second_half = sequential[config.n_layer // 2 :]
|
|
|
assert len(first_half) + len(second_half) == len(sequential)
|
|
|
assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
|
|
|
for m in sequential, first_half, second_half:
|