Kaynağa Gözat

Pack of Inference Changes (#37)

* Return multibatch mode

* Add tests

* fixes
Artem Chumachenko 3 yıl önce
ebeveyn
işleme
d989b94614

+ 6 - 0
src/client/remote_generation.py

@@ -17,6 +17,7 @@ class RemoteGenerationMixin:
     This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
     """
 
+    @torch.no_grad()
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -27,6 +28,7 @@ class RemoteGenerationMixin:
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
+        max_length: Optional[int] = None,
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
@@ -63,6 +65,10 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        if max_length is not None and max_new_tokens is None:
+            max_new_tokens = max_length - inputs.size(1)
+            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
+
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
             inputs = torch.tensor([[bos_token_id]])

+ 1 - 30
src/client/remote_model.py

@@ -17,8 +17,6 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
-from src.utils.generation_algorithms import DecodingAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -156,7 +154,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return transformer_outputs
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
@@ -190,33 +188,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
             self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
             self.lm_head.bias[...] = new_lm_head.bias
 
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        do_sample: Optional[bool] = None,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        eos_token_id: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
-    ) -> torch.Tensor:
-        return RemoteGenerationMixin.generate(
-            self,
-            inputs=inputs,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            eos_token_id=eos_token_id,
-            max_new_tokens=max_new_tokens,
-            decoding_algorithm=decoding_algorithm,
-            provided_constraints=provided_constraints,
-            **model_kwargs,
-        )
-
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig

+ 33 - 1
src/server/backend.py

@@ -1,16 +1,46 @@
 """Code for serving bloom blocks via hivemind-server"""
+from queue import Empty
 from typing import Sequence, Tuple
 
 import torch
+from hivemind import use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
+from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
 
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
 MAX_LENGTH = 2048
 
 
+class InferenceTaskPool(TaskPool):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
+
+    def iterate_minibatches(self, *args, **kwargs):
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+
+        while True:
+            try:
+                logger.debug(f"{self.name} getting next task")
+                task = self.tasks.get(timeout=self.timeout)
+            except Empty:
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
+
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    yield [task]
+            except InvalidStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+
+
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
@@ -23,7 +53,9 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_pool = InferenceTaskPool(
+            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():

+ 30 - 0
tests/test_full_model.py

@@ -4,6 +4,7 @@ import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
 
+from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
@@ -54,3 +55,32 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         else:
             logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
             assert False
+
+
+@pytest.mark.forked
+def test_greedy_generation(max_new_tokens=4):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+    )
+    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
+    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"
+
+    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+        "input_ids"
+    ]
+    remote_outputs_batch = model.generate(
+        inputs_batch,
+        max_new_tokens=max_new_tokens,
+    )
+    hf_outputs_batch = BloomForCausalLM.greedy_search(
+        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
+    )
+    assert torch.allclose(
+        remote_outputs_batch, hf_outputs_batch
+    ), "Greedy search are not identical to HF in multibatch mode"