Artem Chumachenko 3 years ago
parent commit 9123ce57ac
2 files changed with 43 additions and 6 deletions
  1. src/client/remote_model.py  + 0 - 2
  2. tests/test_full_model.py  + 43 - 4

+ 0 - 2
src/client/remote_model.py

@@ -17,8 +17,6 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
-from src.utils.generation_algorithms import DecodingAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 43 - 4
tests/test_full_model.py

@@ -4,18 +4,27 @@ import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
 
+from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-@pytest.mark.forked
-def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
+@pytest.fixture
+def tokenizer():
+    return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture
+def model():
+    return DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+
+
+@pytest.mark.forked
+def test_full_model_exact_match(tokenizer, model, atol_forward=1e-3, atol_inference=1e-3):
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
@@ -54,3 +63,33 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         else:
             logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
             assert False
+
+
+def test_greedy_generation(tokenizer, model, max_new_tokens=4):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+    )
+    hf_outputs = BloomForCausalLM.greedy_search(
+        model,
+        input_ids=inputs,
+        max_length=inputs.size(1) + max_new_tokens
+    )
+    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
+
+    inputs_batch = tokenizer(
+        ["A cat sat on a mat", "A dog sat on a mat"],
+        return_tensors="pt",
+        padding=True
+    )["input_ids"]
+    remote_outputs_batch = model.generate(
+        inputs_batch,
+        max_new_tokens=max_new_tokens,
+    )
+    hf_outputs_batch = BloomForCausalLM.greedy_search(
+        model,
+        input_ids=inputs_batch,
+        max_length=inputs_batch.size(1) + max_new_tokens
+    )
+    assert torch.allclose(remote_outputs_batch, hf_outputs_batch), "Greedy search results are not identical to HF in multi-batch mode"
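
Note on the reference comparison above: calling `BloomForCausalLM.greedy_search(model, ...)` as an unbound method runs the generic `transformers` greedy-search loop on top of the same model's forward pass, bypassing the decoding path that `RemoteGenerationMixin` supplies to `model.generate`, so the two token sequences can be compared directly. A minimal standalone sketch of the same check (MODEL_NAME and INITIAL_PEERS are assumed to come from the repo's test_utils, as in the test file; this is an illustration, not part of the commit):

import torch
import transformers

from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM
from test_utils import *  # assumed to provide MODEL_NAME and INITIAL_PEERS

tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# generate() goes through the RemoteGenerationMixin decoding path, while the
# unbound greedy_search call reuses HF's reference loop with the same weights.
remote = model.generate(inputs, max_new_tokens=4)
reference = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + 4)

# With greedy decoding both paths are deterministic, so the ids should match exactly.
assert torch.allclose(remote, reference)

Both tests now receive `tokenizer` and `model` through the new pytest fixtures, which keeps the setup code in one place instead of duplicating it in each test body.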