瀏覽代碼

Add missing methods for SamplingAlgorithm, fix docstrings (#107)

* Add missing methods for SamplingAlgorithm, fix docstrings

* Add SamplingAlgorithm to _choose_sample_algorithm

* Add test_sampling

* Add a warning if sampling options were passed, but do_sample=False

* Skip the sampling test for now

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Max Ryabinin 2 年之前
父節點
當前提交
bd91be27ea
共有 3 個文件被更改,包括 82 次插入22 次删除
  1. 10 4
      src/petals/client/remote_generation.py
  2. 24 16
      src/petals/utils/generation_algorithms.py
  3. 48 2
      tests/test_full_model.py

+ 10 - 4
src/petals/client/remote_generation.py

@@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import (
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
+    SamplingAlgorithm,
     TopKAlgorithm,
 )
 from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
@@ -22,7 +23,7 @@ class RemoteGenerationMixin:
     A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
     The class exposes can be used for:
         - *greedy decoding*.
-        - *multinomial sampling*.
+        - *multinomial, top-k and top-p sampling*.
         - *beam-search decoding*
 
     This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
@@ -126,6 +127,8 @@ class RemoteGenerationMixin:
             elif num_beams is not None and num_beams > 1:
                 decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
             else:
+                if top_k is not None or top_p is not None:
+                    logger.warning("You passed top_k or top_p but did pass do_sample=True. Running greedy sampling")
                 decoding_algorithm = GreedyAlgorithm()
 
         if num_beams > 1:
@@ -252,7 +255,8 @@ class RemoteGenerationMixin:
         **model_kwargs,
     ) -> torch.LongTensor:
         """
-        Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
+        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
+        If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
 
         :param: input_ids: The input tokens to the model.
         :param: temperature: The temperature to use for sampling.
@@ -341,10 +345,12 @@ class RemoteGenerationMixin:
     ) -> DecodingAlgorithm:
         if (top_k is not None) and (top_p is not None):
             raise ValueError("You have to provide only top_k or top_p for sampling")
-        if top_k:
+        if top_k is not None:
             return TopKAlgorithm(top_k, temperature)
-        elif top_p:
+        elif top_p is not None:
             return NucleusAlgorithm(top_p, temperature)
+        else:
+            return SamplingAlgorithm(temperature)
 
     def _get_constraints(
         self,

+ 24 - 16
src/petals/utils/generation_algorithms.py

@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Tuple
 
 import torch
@@ -9,16 +9,16 @@ HypoIds = torch.Tensor
 
 class DecodingAlgorithm(ABC):
     """
-    An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
+    An abstract class for decoding algorithms. Describes the base function of those algorithms:
+    they have to select new tokens and provide the corresponding hypotheses.
     """
 
-    def __init__(self) -> None:
-        pass
-
+    @abstractmethod
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
+        :return: A tuple of selected token ids and corresponding hypotheses.
+        The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
         """
         pass
 
@@ -30,27 +30,36 @@ class GreedyAlgorithm(DecodingAlgorithm):
 
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
-        Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
+        Returns the most probable token. The second returned object is always a range of integers
+        from 0 to batch_size - 1.
         """
         return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
 
 
 class SamplingAlgorithm(DecodingAlgorithm):
+    def __init__(self, temperature: float = 1.0):
+        self.temperature = temperature
+
     def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
         :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
+        :return: A tuple of selected token ids and corresponding hypotheses.
+        The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
         """
         logits[indices_to_remove] = -float("Inf")
         probs = torch.softmax(logits / self.temperature, -1)
         return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
 
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
+        return self.sample(logits, indices_to_remove)
+
 
 class TopKAlgorithm(SamplingAlgorithm):
     def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
         self.top_k = top_k
-        self.temperature = temperature
 
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
@@ -59,18 +68,17 @@ class TopKAlgorithm(SamplingAlgorithm):
 
 class NucleusAlgorithm(SamplingAlgorithm):
     def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
         self.top_p = top_p
-        self.temperature = temperature
 
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
         probs = torch.softmax(sorted_logits / self.temperature, -1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
-        sorted_indices_to_remove = cumulative_probs > self.top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = False
-        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
-        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+
+        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         return self.sample(logits, indices_to_remove)
 
 

+ 48 - 2
tests/test_full_model.py

@@ -83,7 +83,7 @@ def test_greedy_generation(max_new_tokens=4):
         max_new_tokens=max_new_tokens,
     )
     hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
-    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"
+    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
 
     inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
         "input_ids"
@@ -97,7 +97,53 @@ def test_greedy_generation(max_new_tokens=4):
     )
     assert torch.allclose(
         remote_outputs_batch, hf_outputs_batch
-    ), "Greedy search are not identical to HF in multibatch mode"
+    ), "Greedy search results are not identical to HF in multibatch mode"
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
+@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
+def test_sampling(sampling_options, max_new_tokens=4):
+    torch.manual_seed(0)
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    with torch.random.fork_rng():
+        remote_outputs = model.generate(
+            inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            **sampling_options,
+        )
+    with torch.random.fork_rng():
+        hf_outputs = BloomForCausalLM.sample(
+            model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
+        )
+    assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
+
+    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+        "input_ids"
+    ]
+    with torch.random.fork_rng():
+        remote_outputs_batch = model.generate(
+            inputs_batch,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            **sampling_options,
+        )
+    with torch.random.fork_rng():
+        hf_outputs_batch = BloomForCausalLM.sample(
+            model,
+            input_ids=inputs_batch,
+            max_length=inputs_batch.size(1) + max_new_tokens,
+            logits_warper=logits_warper,
+        )
+    assert torch.allclose(
+        remote_outputs_batch, hf_outputs_batch
+    ), "Sampling results are not identical to HF in multibatch mode"
 
 
 @pytest.mark.forked