浏览代码

how bout now?

justheuristic 3 年之前
父节点
当前提交
bac27a7ac8
共有 3 个文件被更改,包括 6 次插入、3 次删除
  1. 0 3
      .github/workflows/run-tests.yaml
  2. 1 0
      cli/convert_model.py
  3. 5 0
      tests/test_full_model.py

+ 0 - 3
.github/workflows/run-tests.yaml

@@ -98,9 +98,6 @@ jobs:
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
           SERVER4_PID=$!
 
-          tail -f server*.log &
-          LOGS_PID=$!
-
           sleep 60  # wait for server to download layers
           
           PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v

+ 1 - 0
cli/convert_model.py

@@ -60,6 +60,7 @@ if __name__ == "__main__":
     if args.resize_token_embeddings:
         logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
         model.resize_token_embeddings(args.resize_token_embeddings)
+        config.vocab_size = args.resize_token_embeddings
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision

+ 5 - 0
tests/test_full_model.py

@@ -17,6 +17,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+    config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
@@ -45,6 +46,10 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
             ref_model = transformers.BloomForCausalLM.from_pretrained(
                 REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
             )
+            if config.vocab_size < ref_model.config.vocab_size:
+                ref_model.resize_token_embeddings(config.vocab_size)
+                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
+
             dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
             # note: this creates a dummy mask to make the test compatible with older transformer versions
             # prior to https://github.com/huggingface/transformers/pull/17837