Ver Fonte

Add GenerationMixin class (#29)

Add generation abstraction, that's using inference_session.
Added modes:
- Greedy, top-k/top-p sampling
- Multibatch generation
- Constraint abstraction
In the future, will add prefix-tuned generation, beam-search and more hf-like stuff.
Artem Chumachenko há 3 anos atrás
pai
commit
6ee942e915

+ 5 - 2
src/client/remote_block.py

@@ -61,12 +61,15 @@ class RemoteTransformerBlockInferenceSession:
 
     @classmethod
     async def _create(
-        cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
+        cls,
+        remote_module: RemoteTransformerBlock,
+        timeout: Optional[float] = None,
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await remote_module.stub.rpc_inference(
-            cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
+            cls._read_inputs_from_queue(inputs_queue, timeout),
+            timeout=timeout,
         )
         return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
 

+ 240 - 0
src/client/remote_generation.py

@@ -0,0 +1,240 @@
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
+
+
+class RemoteGenerationMixin:
+    """
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
+    The class exposes can be used for:
+        - *greedy decoding*.
+        - *multinomial sampling*.
+
+    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
+    """
+
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head.
+
+        :param inputs: The input tokens to the model.
+        :param do_sample: Whether to sample from the model predictions or take the argmax.
+        :param temperature: The temperature to use for sampling.
+        :param top_k: The number of results to return.
+        :param top_p: The cumulative probability of results to return.
+        :param bos_token_id: The id of the beginning of sentence token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param pad_token_id: The id of the padding token.
+        :param max_new_tokens: The maximum number of tokens to generate.
+        :param decoding_algorithm: The decoding algorithm to use.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional arguments to pass to the model.
+        """
+
+        assert (
+            model_kwargs.get("logits_processor", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert (
+            model_kwargs.get("logits_wrapper", None) is None
+        ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert (
+            model_kwargs.get("stopping_criteria", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+
+        if inputs is None:
+            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+            inputs = torch.tensor([[bos_token_id]])
+
+        if decoding_algorithm is None:
+            if do_sample:
+                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+            else:
+                decoding_algorithm = GreedyAlgorithm()
+
+        constraints = self._get_constraints(
+            inputs=inputs,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            max_new_tokens=max_new_tokens,
+            provided_constraints=provided_constraints,
+        )
+
+        with self.transformer.h.inference_session() as sess:
+            outputs = []
+            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
+                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            else:
+                outputs += [inputs]
+            last_token_id = None
+            seq_idx = outputs[0].size(1)
+            hypo_ids = torch.arange(outputs[0].size(0))
+            while True:
+                embs = self.transformer.word_embeddings(outputs[-1])
+                embs = self.transformer.word_embeddings_layernorm(embs)
+                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = self.transformer.ln_f(hidden_state)
+                lm_logits = self.lm_head(hidden_state)
+
+                for constraint in constraints:
+                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
+                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[
+                        :, seq_idx : seq_idx + 1
+                    ] + pad_token_mask * last_token_id
+
+                if torch.all(last_token_id == eos_token_id):
+                    break
+
+                outputs.append(last_token_id)
+                seq_idx += 1
+
+        return torch.cat(outputs, dim=-1)
+
+    def greedy_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses greedy search.
+
+        :param input_ids: The input tokens to the model.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        """
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=GreedyAlgorithm(),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def sample(
+        self,
+        input_ids: torch.LongTensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
+
+        :param: input_ids: The input tokens to the model.
+        :param: temperature: The temperature to use for sampling.
+        :param: top_k: The number of samples to use for top_k sampling.
+        :param: top_p: The probability of using top_p sampling.
+        :param: max_length: The maximum length of the sequence to generate.
+        :param: pad_token_id: The id of the padding token.
+        :param: eos_token_id: The id of the end of sentence token.
+        :param: provided_constraints: A list of constraints to use.
+        :param: model_kwargs: Additional kwargs to pass to the model.
+        """
+
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def beam_sample(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def group_beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        raise NotImplementedError
+
+    def _choose_sample_algorithm(
+        self,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> DecodingAlgorithm:
+        if (top_k is not None) and (top_p is not None):
+            raise ValueError("You have to provide only top_k or top_p for sampling")
+        if top_k:
+            return TopKAlgorithm(top_k, temperature)
+        elif top_p:
+            return NucleusAlgorithm(top_p, temperature)
+
+    def _get_constraints(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+    ) -> List[ABCBloomConstraint]:
+        constraints = []
+        constraints.extend(provided_constraints)
+        if max_new_tokens is not None:
+            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
+        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
+        return constraints

+ 33 - 3
src/client/remote_model.py

@@ -1,6 +1,6 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import hivemind
 import torch
@@ -15,7 +15,10 @@ from src.bloom.model import (
     BloomPreTrainedModel,
     LMHead,
 )
+from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
+from src.utils.generation_algorithms import DecodingAlgorithm
+from src.utils.generation_constraints import ABCBloomConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -137,8 +140,8 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return transformer_outputs
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM):
-    """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
+    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
 
@@ -171,6 +174,33 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
             self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
             self.lm_head.bias[...] = new_lm_head.bias
 
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        eos_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.Tensor:
+        return RemoteGenerationMixin.generate(
+            self,
+            inputs=inputs,
+            do_sample=do_sample,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            max_new_tokens=max_new_tokens,
+            decoding_algorithm=decoding_algorithm,
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig

+ 14 - 6
src/server/handler.py

@@ -26,7 +26,9 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(module_backend, TransformerBackend)
 
     async def rpc_inference(
-        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+        self,
+        requests: AsyncIterator[runtime_pb2.ExpertRequest],
+        context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
@@ -35,17 +37,21 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_uids = self._check_header(request)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
+            batch_size = request.tensors[0].size[0] if request.tensors else 1
+
+            cache_metadata = torch.tensor(
+                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+            )  # [cache_handle, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends) as cache_handles:
+            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
-                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
+                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                         assert (
                             len(hidden_states) == 1 and hidden_states[0].ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
@@ -213,7 +219,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
+    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -221,7 +227,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+                cache_descriptor = TensorDescriptor(
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

+ 0 - 0
src/utils/__init__.py


+ 78 - 0
src/utils/generation_algorithms.py

@@ -0,0 +1,78 @@
+from abc import ABC
+from typing import Tuple
+
+import torch
+
+TokenIds = torch.Tensor
+HypoIds = torch.Tensor
+
+
+class DecodingAlgorithm(ABC):
+    """
+    An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
+        """
+        pass
+
+
+class GreedyAlgorithm(DecodingAlgorithm):
+    """
+    The simpliest algorithm for decoding. It selects the most probable token.
+    """
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
+        """
+        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
+
+
+class SamplingAlgorithm(DecodingAlgorithm):
+    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
+        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
+        """
+        logits[indices_to_remove] = -float("Inf")
+        probs = torch.softmax(logits / self.temperature, -1)
+        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+
+
+class TopKAlgorithm(SamplingAlgorithm):
+    # TODO: Add NumHypos, maxBatchSize
+    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        self.top_k = top_k
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
+        return self.sample(logits, indices_to_remove)
+
+
+class NucleusAlgorithm(SamplingAlgorithm):
+    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        self.top_p = top_p
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = torch.softmax(sorted_logits / self.temperature, -1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
+        sorted_indices_to_remove = cumulative_probs > self.top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = False
+        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
+        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+        return self.sample(logits, indices_to_remove)
+
+
+# TODO: In generate function we need to check usage of top_k or sampling algorithm

+ 84 - 0
src/utils/generation_constraints.py

@@ -0,0 +1,84 @@
+from abc import ABC
+
+import torch
+
+
+class ABCBloomConstraint(ABC):
+    """
+    Base class of all kind of decoding constraints. It can be used to implement a new constraint.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        """
+        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
+        :param tokens_id: The token id of the last choosen token.
+        :param logits: The logits from the Bloom model.
+        :param hypo_ids: The hypothesis ids of the last tokens.
+        """
+        pass
+
+
+class MaxNewTokensConstraint(ABCBloomConstraint):
+    """
+    Constraint that forbids to generate more than max_new_tokens tokens after the prefix.
+
+    Args:
+        prefix: The prefix of the sequence.
+        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
+        eos_token_id: The id of the end of sentence token.
+        pad_token_id: The id of the padding token.
+        min_logits: The minimum logits that can be generated. Default: -1e6.
+    """
+
+    def __init__(
+        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
+    ) -> None:
+        self.max_new_tokens = max_new_tokens
+        self.current_generated_tokens = None
+        self.eos_token_id = eos_token_id
+        self.min_logits = min_logits
+
+        max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
+        self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        if tokens_id is not None:
+            self.current_generated_tokens += 1
+
+        mask = self.current_generated_tokens >= self.max_new_tokens
+        logits += self.min_logits * mask
+        logits[mask[:, 0], self.eos_token_id] = 0
+        return logits
+
+
+class EosConstraint(ABCBloomConstraint):
+    """
+    This constrained repeats EOS token if it was generated on the previous step.
+    Args:
+        prefix: The prefix of the sequence.
+        eos_token_id: The id of the end of sentence token.
+        pad_token_id: The id of the padding token.
+        min_logits: The minimum logits that can be generated. Default: -1e6.
+    """
+
+    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
+        self.eos_token_id = eos_token_id
+        self.min_logits = min_logits
+        self.past_tokens = None
+
+        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
+
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+        if self.past_tokens is not None:
+            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
+            logits += self.min_logits * mask
+            logits[mask[:, 0], self.eos_token_id] = 0
+
+        if tokens_id is not None:
+            self.past_tokens = tokens_id
+            self.wait_until_starting -= 1
+
+        return logits