artek0chumak 2 years ago
parent
commit
a43265a6ea
1 changed files with 4 additions and 2 deletions
  1. 4 2
      tests/test_full_model.py

+ 4 - 2
tests/test_full_model.py

@@ -14,7 +14,9 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
 @pytest.mark.parametrize("second_token_attention_mask", (1, 0))
-def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(
+    pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3
+):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -40,7 +42,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
             for t in range(embs.shape[1]):
-                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, :t+1]))
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1]))
                 if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))