Commit b7e6a4f12e by Artem Chumachenko (3 years ago)

3 changed files with 27 additions and 29 deletions:
  1. src/client/remote_generation.py  (+5 -0)
  2. src/server/backend.py            (+4 -2)
  3. tests/test_full_model.py         (+18 -27)

src/client/remote_generation.py (+5 -0)

@@ -28,6 +28,7 @@ class RemoteGenerationMixin:
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
+        max_length: Optional[int] = None,
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
@@ -64,6 +65,10 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        if max_length is not None and max_new_tokens is None:
+            max_new_tokens = max_length - inputs.size(1)
+            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
+
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
             inputs = torch.tensor([[bos_token_id]])
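
The new max_length argument is converted into max_new_tokens relative to the prompt length, so the two calls below should request the same number of generated tokens. A minimal usage sketch, assuming a tokenizer and a DistributedBloomForCausalLM instance named model are already set up as in tests/test_full_model.py:

    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    # max_length counts the prompt tokens, max_new_tokens does not
    out_a = model.generate(inputs, max_length=inputs.size(1) + 4)
    out_b = model.generate(inputs, max_new_tokens=4)

    # A max_length no larger than the prompt length trips the new assertion above.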

src/server/backend.py (+4 -2)

@@ -1,6 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Sequence, Tuple
 from queue import Empty
+from typing import Sequence, Tuple
 
 import torch
 from hivemind import use_hivemind_log_handler
@@ -53,7 +53,9 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = InferenceTaskPool(self.inference_step, max_batch_size=4096, name=f"{self.name}_inference")
+        self.inference_pool = InferenceTaskPool(
+            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
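
For context: the inference pool's batch limit is no longer hard-coded to 4096 but is taken from the forward pool, so a single configured value governs both kinds of requests. A dependency-free sketch of that design choice; SimplePool is invented purely for illustration and is not a hivemind or repository class:

    from dataclasses import dataclass

    @dataclass
    class SimplePool:
        name: str
        max_batch_size: int

    def make_pools(max_batch_size: int):
        forward_pool = SimplePool("block_forward", max_batch_size)
        # Before this commit the inference limit was a separate constant (4096);
        # now it follows whatever limit the forward pool was configured with.
        inference_pool = SimplePool("block_inference", forward_pool.max_batch_size)
        return forward_pool, inference_pool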

tests/test_full_model.py (+18 -27)

@@ -11,20 +11,12 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-@pytest.fixture
-def tokenizer():
-    return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-
-
-@pytest.fixture
-def model():
-    return DistributedBloomForCausalLM.from_pretrained(
+@pytest.mark.forked
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
-
-
-@pytest.mark.forked
-def test_full_model_exact_match(tokenizer, model, atol_forward=1e-3, atol_inference=1e-3):
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
@@ -65,31 +57,30 @@ def test_full_model_exact_match(tokenizer, model, atol_forward=1e-3, atol_infere
             assert False
 
 
-def test_greedy_generation(tokenizer, model, max_new_tokens=4):
+@pytest.mark.forked
+def test_greedy_generation(max_new_tokens=4):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs = BloomForCausalLM.greedy_search(
-        model,
-        input_ids=inputs,
-        max_length=inputs.size(1) + max_new_tokens
-    )
+    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
     assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"
 
-    inputs_batch = tokenizer(
-        ["A cat sat on a mat", "A dog sat on a mat"],
-        return_tensors='pt',
-        padding=True
-    )["input_ids"]
+    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+        "input_ids"
+    ]
     remote_outputs_batch = model.generate(
         inputs_batch,
         max_new_tokens=max_new_tokens,
     )
     hf_outputs_batch = BloomForCausalLM.greedy_search(
-        model,
-        input_ids=inputs_batch,
-        max_length=inputs_batch.size(1) + max_new_tokens
+        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
     )
-    assert torch.allclose(remote_outputs_batch, hf_outputs_batch), "Greedy search are not identical to HF in multibatch mode"
+    assert torch.allclose(
+        remote_outputs_batch, hf_outputs_batch
+    ), "Greedy search are not identical to HF in multibatch mode"