Browse Source

fix decorator; avoid loading model and ref_model at the same time

justheuristic 3 years ago
parent
commit
653a85e6a2
1 changed files with 35 additions and 31 deletions
  1. 35 31
      tests/test_full_model.py

+ 35 - 31
tests/test_full_model.py

@@ -23,41 +23,45 @@ if not MODEL_NAME:
 REF_NAME = os.environ.get("REF_NAME")
 
 
-@torch.inference_mode
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
-    assert not torch.is_grad_enabled()
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    parallel_outputs = model.forward(test_inputs).logits
-    assert torch.all(torch.isfinite(parallel_outputs))
-    logger.info("Forward outputs are finite")
-
-    if REF_NAME:
-        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-        # note: this creates a dummy mask to make the test compatible with older transformer versions
-        # prior to https://github.com/huggingface/transformers/pull/17837
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
-        del ref_model, ref_outputs, dummy_mask
-    else:
-        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
-
-    embs = model.transformer.word_embeddings(test_inputs)
-    embs = model.transformer.word_embeddings_layernorm(embs)
-    recurrent_outputs = []
-    with model.transformer.h.inference_session() as sess:
-        for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
-    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
-    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-    dictionary = model.transformer.word_embeddings.weight.t()
-    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-    recurrent_outputs = (recurrent_outputs @ dictionary).float()
-    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
-    logger.info("Inference is consistent with forward")
+
+    with torch.no_grad():
+        parallel_outputs = model.forward(test_inputs).logits
+        assert torch.all(torch.isfinite(parallel_outputs))
+        logger.info("Forward outputs are finite")
+
+        embs = model.transformer.word_embeddings(test_inputs)
+        embs = model.transformer.word_embeddings_layernorm(embs)
+        recurrent_outputs = []
+        with model.transformer.h.inference_session() as sess:
+            for t in range(embs.shape[1]):
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
+        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
+
+        dictionary = model.transformer.word_embeddings.weight.t()
+        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
+        recurrent_outputs = (recurrent_outputs @ dictionary).float()
+        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
+        logger.info("Inference is consistent with forward")
+
+        del model, recurrent_outputs
+
+        if REF_NAME:
+            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+            # note: this creates a dummy mask to make the test compatible with older transformer versions
+            # prior to https://github.com/huggingface/transformers/pull/17837
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+            logger.warning(f"{type(model)}.forward is consistent with {type(ref_model)}.forward")
+            del ref_model, ref_outputs, dummy_mask
+        else:
+            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+