|
@@ -24,9 +24,10 @@ if not MODEL_NAME:
|
|
|
REF_NAME = os.environ.get("REF_NAME")
|
|
|
|
|
|
|
|
|
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
|
|
|
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
|
|
+ assert isinstance(model, DistributedBloomForCausalLM)
|
|
|
assert len(model.transformer.h) == model.config.n_layer
|
|
|
|
|
|
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
|
|
@@ -35,26 +36,29 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
|
|
|
logger.info("Forward outputs are finite")
|
|
|
|
|
|
if REF_NAME:
|
|
|
- ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
|
|
|
- dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
|
|
|
- # note: this creates a dummy mask to make the test compatible with older transformer versions
|
|
|
- # prior to https://github.com/huggingface/transformers/pull/17837
|
|
|
- ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
|
|
|
- assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
|
|
|
+ with torch.no_grad():
|
|
|
+ ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
|
|
|
+ dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
|
|
|
+ # note: this creates a dummy mask to make the test compatible with older transformer versions
|
|
|
+ # prior to https://github.com/huggingface/transformers/pull/17837
|
|
|
+ ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
|
|
|
+ assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
|
|
|
+ del ref_model, ref_outputs
|
|
|
else:
|
|
|
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
|
|
|
|
|
|
- embs = model.transformer.word_embeddings(test_inputs)
|
|
|
- embs = model.transformer.word_embeddings_layernorm(embs)
|
|
|
- recurrent_outputs = []
|
|
|
- with model.transformer.h.inference_session() as sess:
|
|
|
- for t in range(embs.shape[1]):
|
|
|
- recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
|
|
|
- recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
|
|
|
- recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
|
|
|
-
|
|
|
- dictionary = model.transformer.word_embeddings.weight.t()
|
|
|
- recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
|
|
|
- recurrent_outputs = (recurrent_outputs @ dictionary).float()
|
|
|
+ with torch.inference_mode():
|
|
|
+ embs = model.transformer.word_embeddings(test_inputs)
|
|
|
+ embs = model.transformer.word_embeddings_layernorm(embs)
|
|
|
+ recurrent_outputs = []
|
|
|
+ with model.transformer.h.inference_session() as sess:
|
|
|
+ for t in range(embs.shape[1]):
|
|
|
+ recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
|
|
|
+ recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
|
|
|
+ recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
|
|
|
+
|
|
|
+ dictionary = model.transformer.word_embeddings.weight.t()
|
|
|
+ recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
|
|
|
+ recurrent_outputs = (recurrent_outputs @ dictionary).float()
|
|
|
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
|
|
|
logger.info("Inference is consistent with forward")
|