Sfoglia il codice sorgente

Added primitives for speculative decoding and tests (#598)

This PR creates a DistributedLlamaModelForSpeculativeGeneration that implements basic speculative decoding (currently for greedy inference only).
Anton Sinitsin 1 anno fa
parent
commit
02bbd85ed8

+ 21 - 15
src/petals/client/inference_session.py

@@ -83,6 +83,17 @@ class _ServerInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
+    @property
+    def position(self):
+        return self._position
+
+    @position.setter
+    def position(self, start_from_position: int):
+        assert start_from_position <= self._position
+        self._position = start_from_position
+        if self.history is not None and self.history.shape[1] >= start_from_position:
+            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
     def step(
         self,
         inputs: torch.Tensor,
@@ -90,7 +101,6 @@ class _ServerInferenceSession:
         hypo_ids: torch.LongTensor,
         *,
         step_id: str,
-        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -100,12 +110,6 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
-        if start_from_position is not None:
-            assert start_from_position <= self._position
-            self._position = start_from_position
-            if self.history is not None and self.history.shape[1] >= start_from_position:
-                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
-
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -127,8 +131,8 @@ class _ServerInferenceSession:
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-        if start_from_position is not None:
-            request_metadata["start_from_position"] = start_from_position
+        if self._position is not None:
+            request_metadata["start_from_position"] = self._position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -235,6 +239,13 @@ class InferenceSession:
     def position(self) -> int:
         return self._position
 
+    @position.setter
+    def position(self, start_from_position: int) -> None:
+        self._position = start_from_position
+        for session in self._server_sessions:
+            assert isinstance(session, _ServerInferenceSession)
+            session.position = start_from_position
+
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
         try:
@@ -275,12 +286,7 @@ class InferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
-
-        if start_from_position is not None:
-            self._position = start_from_position
-
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -324,12 +330,12 @@ class InferenceSession:
                         self._update_sequence(server_idx, block_idx, attempt_no)
 
                     server_session = self._server_sessions[server_idx]
+                    assert server_session.position == self.position, f"{server_session.position} and {self.position}"
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],
                         hypo_ids,
                         step_id=step_id,
-                        start_from_position=start_from_position,
                     )
 
                     server_idx += 1

+ 2 - 0
src/petals/models/llama/__init__.py

@@ -5,11 +5,13 @@ from petals.models.llama.model import (
     DistributedLlamaForSequenceClassification,
     DistributedLlamaModel,
 )
+from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
 from petals.utils.auto_config import register_model_classes
 
 register_model_classes(
     config=DistributedLlamaConfig,
     model=DistributedLlamaModel,
     model_for_causal_lm=DistributedLlamaForCausalLM,
+    model_for_speculative=DistributedLlamaForSpeculativeGeneration,
     model_for_sequence_classification=DistributedLlamaForSequenceClassification,
 )

+ 111 - 0
src/petals/models/llama/speculative_model.py

@@ -0,0 +1,111 @@
+from typing import Optional, Union
+
+import torch
+from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama import LlamaForCausalLM
+
+from petals.models.llama.config import DistributedLlamaConfig
+from petals.models.llama.model import DistributedLlamaForCausalLM
+
+
+class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
+    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
+        DistributedLlamaForCausalLM.__init__(self, config)
+        self.small_model = small_model
+
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool,
+        streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList],
+        speculative_inference_iteration_size: int = 10,
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        assert not generation_config.do_sample, "sample is not working for speculative generation now"
+        assert not synced_gpus, "synced_gpus is not working for speculative generation now"
+        assert (
+            not generation_config.return_dict_in_generate
+        ), "return_dict_in_generate is not working for speculative generation now"
+
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+        # keep track of which sequences are already finished
+        batch_size = input_ids.shape[0]
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        finished = False
+        firsts = True
+
+        while not finished:
+            speculative_inference_iteration_size = min(
+                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
+            )
+            with torch.no_grad():
+                speculative_outputs = self.small_model.generate(
+                    input_ids,
+                    max_new_tokens=speculative_inference_iteration_size,
+                    do_sample=False,
+                )
+                speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]
+
+            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
+            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
+
+            input_for_validation = full_sequence
+            if not firsts:
+                self.active_session.position = input_ids.shape[1] - 1
+                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
+            else:
+                firsts = False
+            input_for_validation = input_for_validation[:, :-1]
+            with torch.no_grad():
+                precise_model_outputs = self(input_for_validation)
+            full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()
+
+            all_valid_tokens = []
+            first_token = None
+            for i in range(speculative_inference_iteration_size):
+                token_logits = full_token_logits[:, i, :]
+                token_scores = logits_processor(
+                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
+                )
+                valid_token = torch.argmax(token_scores, dim=-1)
+
+                if first_token is None:
+                    first_token = valid_token
+
+                if valid_token.item() == speculative_tokens[:, i].item():
+                    all_valid_tokens.append(valid_token.unsqueeze(-1))
+                else:
+                    break
+
+            if not all_valid_tokens and first_token is not None:
+                all_valid_tokens.append(first_token.unsqueeze(-1))
+            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)
+
+            if streamer is not None:
+                streamer.put(all_valid_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
+            finished = unfinished_sequences.max() == 0
+
+            del precise_model_outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        return input_ids

+ 1 - 0
src/petals/utils/__init__.py

@@ -3,5 +3,6 @@ from petals.utils.auto_config import (
     AutoDistributedModel,
     AutoDistributedModelForCausalLM,
     AutoDistributedModelForSequenceClassification,
+    AutoDistributedSpeculativeModel,
 )
 from petals.utils.dht import declare_active_modules, get_remote_module_infos

+ 5 - 0
src/petals/utils/auto_config.py

@@ -15,6 +15,7 @@ class _ModelClasses:
     config: Type[PretrainedConfig]
     model: Optional[Type[PreTrainedModel]] = None
     model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
+    model_for_speculative: Optional[Type[PreTrainedModel]] = None
     model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
 
 
@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
     _mapping_field = "model_for_causal_lm"
 
 
+class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
+    _mapping_field = "model_for_speculative"
+
+
 class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"

+ 52 - 2
tests/test_speculative_generation.py

@@ -2,8 +2,14 @@ import random
 
 import pytest
 import torch
+import transformers
 
-from petals import AutoDistributedConfig, RemoteSequential
+from petals import (
+    AutoDistributedConfig,
+    AutoDistributedSpeculativeModel,
+    DistributedLlamaForSpeculativeGeneration,
+    RemoteSequential,
+)
 from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
@@ -26,10 +32,54 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+            sess.position = 2
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
     (outputs_local,) = ref_block(short_inputs)
 
     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+
+
+@pytest.fixture
+def noisy_model():
+    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    lm_head = noisy_model.get_output_embeddings()
+    assert isinstance(lm_head, torch.nn.Linear)
+    with torch.no_grad():
+        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
+    return noisy_model
+
+
+@pytest.fixture
+def model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
+@pytest.mark.forked
+@pytest.mark.skipif(
+    "llama" not in MODEL_NAME.lower(),
+    reason="Speculative generation now works only for llama models",
+)
+def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
+    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
+    )
+
+    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+
+    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)