浏览代码

black-isort

justheuristic 3 年之前
父节点
当前提交
52238e8e5c
共有 1 个文件被更改,包括 3 次插入3 次删除
  1. 3 3
      tests/test_remote_sequential.py

+ 3 - 3
tests/test_remote_sequential.py

@@ -5,7 +5,7 @@ import transformers
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
 from src import RemoteSequential
-from src.client.remote_model import DistributedBloomForCausalLM, DistributedBloomConfig
+from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -36,8 +36,8 @@ def test_remote_sequential():
     full_grad = test_inputs.grad.clone()
     test_inputs.grad.data.zero_()
 
-    first_half = sequential[:config.n_layer // 2]
-    second_half = sequential[config.n_layer // 2:]
+    first_half = sequential[: config.n_layer // 2]
+    second_half = sequential[config.n_layer // 2 :]
     assert len(first_half) + len(second_half) == len(sequential)
     assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
     for m in sequential, first_half, second_half: