set dtype='auto'

justheuristic 3 years ago
parent commit dfe57ebf9a
1 file changed with 3 additions and 6 deletions
      tests/test_full_model.py

+ 3 - 6
tests/test_full_model.py

@@ -13,7 +13,7 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype="auto")
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
@@ -32,17 +32,14 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
         recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-        dictionary = model.transformer.word_embeddings.weight.t()
-        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-        recurrent_outputs = (recurrent_outputs @ dictionary).float()
+        recurrent_outputs = model.lm_head(recurrent_outputs)
         assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
         logger.info("Inference is consistent with forward")
 
         del model, recurrent_outputs
 
         if REF_NAME:
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME, torch_dtype="auto")
             dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
             # note: this creates a dummy mask to make the test compatible with older transformer versions
             # prior to https://github.com/huggingface/transformers/pull/17837
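
Note: a minimal sketch of what the two changes rely on, using bigscience/bloom-560m as a stand-in checkpoint (an assumption; the test's MODEL_NAME and REF_NAME are defined elsewhere). torch_dtype="auto" makes from_pretrained load weights in the dtype recorded in the checkpoint config instead of transformers' float32 default, and BLOOM's lm_head is tied to the input word embeddings, so applying it reproduces the manual projection that this commit deletes:

    import torch
    import transformers

    # Assumption: bigscience/bloom-560m stands in for the test's MODEL_NAME/REF_NAME.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m", torch_dtype="auto"
    )
    print(model.dtype)  # the dtype stored in the checkpoint config, not torch.float32

    # Cast to float32 so the check avoids half-precision CPU kernels, then verify
    # that lm_head(h) matches the manual projection h @ word_embeddings.weight.t()
    # that the deleted lines computed by hand:
    model = model.float()
    hidden = torch.randn(1, 4, model.config.hidden_size)
    manual = hidden @ model.transformer.word_embeddings.weight.t()
    assert torch.allclose(model.lm_head(hidden), manual, atol=1e-5)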