
more hf like

artek0chumak 3 years ago
parent commit e6884291b7

+ 206 - 35
src/client/remote_generation.py

@@ -1,14 +1,21 @@
 import torch
+import torch.nn.functional as F
 
 from typing import List, Optional
 
 from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, TopKAlgorithm, NucleusAlgorithm
-from src.utils.generation_constraints import ABConstraint, MaxNewTokensConstraint
+from src.utils.generation_constraints import ABCBloomConstraint, MaxNewTokensConstraint, EosConstraint
 
-from transformers.modeling_utils import PreTrainedModel
 
-
-class RemoteGenerationMixin(PreTrainedModel):
+class RemoteGenerationMixin:
+    """
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
+    The class exposes methods that can be used for:
+        - *greedy decoding*.
+        - *multinomial sampling*.
+    
+    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used as a drop-in replacement for it, with some differences.
+    """
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -16,50 +23,214 @@ class RemoteGenerationMixin(PreTrainedModel):
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABConstraint] = [],
+        provided_constraints: List[ABCBloomConstraint] = [],
         **model_kwargs,
-    ) -> torch.Tensor:
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head.
+        
+        :param inputs: The input tokens to the model.
+        :param do_sample: Whether to sample from the model predictions or take the argmax.
+        :param temperature: The temperature to use for sampling.
+        :param top_k: The number of highest-probability tokens to keep for top-k sampling.
+        :param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
+        :param bos_token_id: The id of the beginning of sentence token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param pad_token_id: The id of the padding token.
+        :param max_new_tokens: The maximum number of tokens to generate.
+        :param decoding_algorithm: The decoding algorithm to use.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional arguments to pass to the model.
+        """
+
+        assert model_kwargs.get("logits_processor", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert model_kwargs.get("logits_wrapper", None) is None, "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert model_kwargs.get("stopping_criteria", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+
+        word_embeddings = self.transformer.word_embeddings.weight
+
+        if inputs is None:
+            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+            inputs = torch.tensor([[bos_token_id]])
+
         if decoding_algorithm is None:
             if do_sample:
-                if (top_k is None) == (top_p is None):
-                    raise ValueError("You have to provide only top_k or top_p for sampling")
-                if top_k:
-                    decoding_algorithm = TopKAlgorithm(top_k, temperature)
-                elif top_p:
-                    decoding_algorithm = NucleusAlgorithm(top_p, temperature)
+                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
             else:
                 decoding_algorithm = GreedyAlgorithm()
 
-        constraints = []
-        constraints.extend(provided_constraints)
-
-        if max_new_tokens and eos_token_id:
-            constraints.append(MaxNewTokensConstraint(max_new_tokens, eos_token_id))
-
-        for constraint in constraints:
-            constraint.consume_prefix(inputs)
-
-        word_embeddings = self.transformer.word_embeddings.weight.t()
+        constraints = self._get_constraints(
+            inputs=inputs,
+            eos_token_id=eos_token_id, 
+            pad_token_id=pad_token_id, 
+            max_new_tokens=max_new_tokens, 
+            provided_constraints=provided_constraints,
+        )
 
         with self.transformer.h.inference_session() as sess:
-            last_token_id = inputs[:, -1]
-            outputs = [inputs]
-            while torch.any(last_token_id != eos_token_id):
-                embs = self.transformer.word_embeddings(inputs)
+            outputs = []
+            if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
+                outputs += [inputs[:, :inputs.size(1) - (inputs == pad_token_id).sum(-1).min()]]
+            else:
+                outputs += [inputs]
+            last_token_id = None
+            seq_idx = outputs[0].size(1)
+            hypo_ids = torch.arange(outputs[0].size(0))
+            while True:
+                embs = self.transformer.word_embeddings(outputs[-1])
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                for emb_ids in range(embs.size(1)):
-                    recurrent_output = sess.step(embs[:, emb_ids:emb_ids+1])
-                recurrent_output = self.transformer.ln_f(recurrent_output)
-                lm_logits = (recurrent_output @ word_embeddings).float()
-                for constraint in constraints:
-                    lm_logits = constraint.calculate_transation(lm_logits)
-                last_token_id, _ = decoding_algorithm(lm_logits)
+                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = self.transformer.ln_f(hidden_state)
+                lm_logits = F.linear(hidden_state, word_embeddings).float()
+
                 for constraint in constraints:
-                    constraint.update(last_token_id, torch.ones_like(last_token_id))
+                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
+                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+                if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[:, seq_idx : seq_idx + 1] + pad_token_mask * last_token_id
+
+                if torch.all(last_token_id == eos_token_id):
+                    break
+
                 outputs.append(last_token_id)
-                inputs = last_token_id
+                seq_idx += 1
 
         return torch.cat(outputs, dim=-1)
+
+    def greedy_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses greedy search.
+
+        :param input_ids: The input tokens to the model.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        """
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=GreedyAlgorithm(),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def sample(
+        self,
+        input_ids: torch.LongTensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling: if top_k is provided, uses top-k sampling; if top_p is provided, uses nucleus sampling.
+        
+        :param input_ids: The input tokens to the model.
+        :param temperature: The temperature to use for sampling.
+        :param top_k: The number of highest-probability tokens to keep for top-k sampling.
+        :param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional kwargs to pass to the model.
+        """
+
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def beam_sample(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def group_beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def _choose_sample_algorithm(
+        self,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> DecodingAlgorithm:
+        if (top_k is None) == (top_p is None):
+            raise ValueError("You have to provide exactly one of top_k or top_p for sampling")
+        if top_k is not None:
+            return TopKAlgorithm(top_k, temperature)
+        return NucleusAlgorithm(top_p, temperature)
+
+    def _get_constraints(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+    ) -> List[ABCBloomConstraint]:
+        constraints = []
+        constraints.extend(provided_constraints)
+        if max_new_tokens is not None:
+            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
+        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
+        return constraints
+
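For context, here is a minimal usage sketch of the new generate() API (not part of this commit). It assumes `model` is an already-constructed DistributedBloomForCausalLM connected to a swarm and `tokenizer` is a matching BLOOM tokenizer; both objects are stand-ins and are not defined in this diff.

    # Sketch only: `model` and `tokenizer` are assumed to exist (see src/client/remote_model.py).
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    # Greedy decoding: do_sample=False falls back to GreedyAlgorithm.
    greedy_ids = model.generate(inputs, max_new_tokens=8)

    # Multinomial sampling: exactly one of top_k / top_p selects the decoding algorithm.
    topk_ids = model.generate(inputs, do_sample=True, top_k=50, temperature=0.8, max_new_tokens=8)
    nucleus_ids = model.generate(inputs, do_sample=True, top_p=0.9, max_new_tokens=8)

    print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))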

+ 11 - 11
src/client/remote_model.py

@@ -20,7 +20,7 @@ from src.bloom.model import (
 from src.client.remote_sequential import RemoteSequential
 from src.client.remote_generation import RemoteGenerationMixin
 from src.utils.generation_algorithms import DecodingAlgorithm
-from src.utils.generation_constraints import ABConstraint
+from src.utils.generation_constraints import ABCBloomConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -185,20 +185,20 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
         eos_token_id: Optional[int] = None,
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABConstraint] = [],
+        provided_constraints: List[ABCBloomConstraint] = [],
         **model_kwargs,
     ) -> torch.Tensor:
         return RemoteGenerationMixin.generate(
             self,
-            inputs,
-            do_sample,
-            temperature,
-            top_k,
-            top_p,
-            eos_token_id,
-            max_new_tokens,
-            decoding_algorithm,
-            provided_constraints,
+            inputs=inputs,
+            do_sample=do_sample,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            max_new_tokens=max_new_tokens,
+            decoding_algorithm=decoding_algorithm,
+            provided_constraints=provided_constraints,
             **model_kwargs,
         )
 

+ 35 - 16
src/utils/generation_algorithms.py

@@ -4,43 +4,64 @@ from abc import ABC
 from typing import Tuple
 
 TokenIds = torch.Tensor
-BatchIds = torch.Tensor
+HypoIds = torch.Tensor
 
 
 class DecodingAlgorithm(ABC):
+    """
+    An abstract class for decoding algorithms. It describes the base interface: an algorithm selects new tokens and returns the corresponding hypothesis ids.
+    """
     def __init__(self) -> None:
         pass
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
+        """
         pass
 
 
 class GreedyAlgorithm(DecodingAlgorithm):
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
-        return logits.max(-1)[1], torch.arange(logits.size(0))
+    """
+    The simplest decoding algorithm. It selects the most probable token.
+    """
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        Returns the most probable token. The second returned value is always a range of integers from 0 to batch_size - 1.
+        """
+        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
 
 
-class TopKAlgorithm(DecodingAlgorithm):
+class SamplingAlgorithm(DecodingAlgorithm):
+    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
+        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size). 
+        """
+        logits[indices_to_remove] = -float("Inf")
+        probs = torch.softmax(logits / self.temperature, -1)
+        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+
+
+class TopKAlgorithm(SamplingAlgorithm):
     # TODO: Add NumHypos, maxBatchSize
     def __init__(self, top_k: int, temperature: float = 1.0) -> None:
         self.top_k = top_k
         self.temperature = temperature
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
-        logits = logits[:, -1]
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
-        logits[indices_to_remove] = -float("Inf")
-        probs = torch.softmax(logits / self.temperature, -1)
-        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+        return self.sample(logits, indices_to_remove)
 
 
-class NucleusAlgorithm(DecodingAlgorithm):
+class NucleusAlgorithm(SamplingAlgorithm):
     def __init__(self, top_p: float, temperature: float = 1.0) -> None:
         self.top_p = top_p
         self.temperature = temperature
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
-        logits = logits[:, -1]
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.softmax(sorted_logits / self.temperature, -1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
@@ -49,9 +70,7 @@ class NucleusAlgorithm(DecodingAlgorithm):
         sorted_indices_to_remove[..., 0] = False
         indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
         indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
-        logits[indices_to_remove] = -float("Inf")
-        probs = torch.softmax(logits / self.temperature, -1)
-        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+        return self.sample(logits, indices_to_remove)
 
 
 # TODO: In generate function we need to check usage of top_k or sampling algorithm
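To make the filtering step above concrete, here is a self-contained PyTorch sketch (not part of this commit) of the masks that TopKAlgorithm and NucleusAlgorithm build before delegating to SamplingAlgorithm.sample(); the tensor values are dummies.

    import torch

    logits = torch.randn(2, 16)          # (batch_size, vocab_size), dummy values
    temperature, top_k, top_p = 0.8, 4, 0.9

    # Top-k: remove everything below the k-th largest logit of each row.
    kth_value = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
    topk_removed = logits < kth_value

    # Nucleus: remove tokens outside the smallest set whose cumulative probability stays within top_p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.softmax(sorted_logits / temperature, -1).cumsum(dim=-1)
    sorted_removed = cumulative_probs > top_p
    sorted_removed[..., 0] = False  # always keep the most probable token
    nucleus_removed = torch.zeros_like(sorted_removed).scatter_(-1, sorted_indices, sorted_removed)

    # Shared sampling step (mirrors SamplingAlgorithm.sample): mask, renormalize, draw one token per row.
    masked = logits.masked_fill(topk_removed, -float("inf"))  # or nucleus_removed
    probs = torch.softmax(masked / temperature, -1)
    token_ids = torch.multinomial(probs, num_samples=1)   # shape (batch_size, 1)
    hypo_ids = torch.arange(logits.size(0))               # trivial hypothesis ids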

+ 63 - 20
src/utils/generation_constraints.py

@@ -3,33 +3,76 @@ import torch
 from abc import ABC
 
 
-class ABConstraint(ABC):
+class ABCBloomConstraint(ABC):
+    """
+    Base class for all kinds of decoding constraints. Subclass it to implement a new constraint.
+    """
     def __init__(self) -> None:
         pass
 
-    def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        """
+        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
+        :param tokens_id: The token id of the last chosen token.
+        :param logits: The logits from the Bloom model.
+        :param hypo_ids: The hypothesis ids of the last tokens.
+        """
         pass
 
-    def consume_prefix(self, prefix: torch.Tensor) -> None:
-        pass
 
-    def calculate_transation(self, logits: torch.Tensor) -> torch.Tensor:
-        pass
-    
-    
-class MaxNewTokensConstraint(ABConstraint):
-    def __init__(self, max_new_tokens: int, eos_token_id: int, min_logits: float = -100000) -> None:
+class MaxNewTokensConstraint(ABCBloomConstraint):
+    """
+    Constraint that forbids generating more than max_new_tokens tokens after the prefix.
+
+    Args:
+        prefix: The prefix of the sequence.
+        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
+        eos_token_id: The id of the end of sentence token.
+        pad_token_id: The id of the padding token.
+        min_logits: The minimum logits that can be generated. Default: -1e6.
+    """
+    def __init__(self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e6) -> None:
         self.max_new_tokens = max_new_tokens
-        self.current_generated_tokens = 0
+        self.current_generated_tokens = None
+        self.eos_token_id = eos_token_id
+        self.min_logits = min_logits
+
+        self.current_generated_tokens = -(prefix == pad_token_id).sum(-1)
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        if tokens_id is not None:
+            self.current_generated_tokens += 1
+
+        mask = (self.current_generated_tokens > self.max_new_tokens).unsqueeze(1)
+        logits += self.min_logits * mask
+        logits[mask[:, 0], self.eos_token_id] = 0
+        return logits
+
+
+class EosConstraint(ABCBloomConstraint):
+    """
+    This constraint forces the model to repeat the EOS token once it has been generated.
+    Args:
+        prefix: The prefix of the sequence.
+        eos_token_id: The id of the end of sentence token.
+        pad_token_id: The id of the padding token.
+        min_logits: The minimum logits that can be generated. Default: -1e6.
+    """
+    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e6) -> None:
         self.eos_token_id = eos_token_id
         self.min_logits = min_logits
-    
-    def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
-        self.current_generated_tokens += 1
-
-    def calculate_transation(self, logits: torch.Tensor) -> torch.Tensor:
-        if self.current_generated_tokens > self.max_new_tokens:
-            mask = torch.zeros_like(logits)
-            mask[..., self.eos_token_id] = 1
-            logits += self.min_logits * (1 - mask)
+        self.past_tokens = None
+
+        self.wait_until_starting = (prefix == pad_token_id).sum(-1).unsqueeze(1)
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        if self.past_tokens is not None:
+            mask = ((self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id))
+            logits += self.min_logits * mask
+            logits[mask[:, 0], self.eos_token_id] = 0
+        
+        if tokens_id is not None:
+            self.past_tokens = tokens_id
+            self.wait_until_starting -= 1
+
         return logits
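
To illustrate how the ABCBloomConstraint interface composes with generate(..., provided_constraints=[...]), here is a hypothetical constraint (not part of this commit) that suppresses a single token id at every step; the class name and its behaviour are illustrative only.

    import torch

    from src.utils.generation_constraints import ABCBloomConstraint


    class ForbiddenTokenConstraint(ABCBloomConstraint):
        """Hypothetical example: push one token's logit so low that it is never selected."""

        def __init__(self, forbidden_token_id: int, min_logits: float = -1e6) -> None:
            self.forbidden_token_id = forbidden_token_id
            self.min_logits = min_logits

        def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
            # Called once per generation step; logits has shape (batch_size, vocab_size).
            logits[:, self.forbidden_token_id] = self.min_logits
            return logits


    # Usage sketch, assuming a DistributedBloomForCausalLM instance named `model`:
    # model.generate(inputs, max_new_tokens=8, provided_constraints=[ForbiddenTokenConstraint(3)])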