Browse Source

preliminary shallow prompt tuning

dbaranchuk 3 years ago
parent
commit
afce0621b5
3 changed files with 624 additions and 5 deletions
  1. 402 0
      notebooks/prompt_tuning_example.ipynb
  2. 134 2
      src/bloom/model.py
  3. 88 3
      src/client/remote_model.py

File diff suppressed because it is too large
+ 402 - 0
notebooks/prompt_tuning_example.ipynb


+ 134 - 2
src/bloom/model.py

@@ -10,10 +10,15 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, LayerNorm
 from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
                                      add_start_docstrings_to_model_forward)
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import logging
@@ -469,3 +474,130 @@ class LMHead(nn.Module):
             chunk = word_embeddings[i: i + self.chunk_size].float()
             output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
         return output
+
+
+@add_start_docstrings(
+    """
+    The Bloom Model transformer with a sequence classification head on top (linear layer).
+    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-1) do.
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class BloomForSequenceClassification(BloomPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = BloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+            
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )

+ 88 - 3
src/client/remote_model.py

@@ -2,10 +2,13 @@
 import os
 from typing import Optional, Tuple
 
+import torch
+import torch.nn as nn
+
 import hivemind
 from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -22,7 +25,8 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU 
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
+    num_prefix_tokens: int = 0 # a number of tokens for prompt tuning. 
 
 
 class DistributedBloomModel(BloomModel):
@@ -54,14 +58,81 @@ class DistributedBloomModel(BloomModel):
             p.requires_grad = value
 
 
+class DistributedBloomPrefix(DistributedBloomModel):
+    """DistributedBloomModel with prefix tokens for prompt tuning"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
+        self.prefix_length = config.num_prefix_tokens
+
+        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.prefix_length).long()
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+        return prompts
+
+    def forward(
+        self, 
+        input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values=None,
+        position_ids=None,
+        head_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None
+    ):
+        assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
+        
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+    
+        batch_size = inputs_embeds.shape[0]
+
+        if attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        prompts = self.get_prompt(batch_size)
+        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        transformer_outputs = super().forward(
+            inputs_embeds=inputs_embeds, 
+            attention_mask=attention_mask, 
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        # Remove prefix
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
+        transformer_outputs['last_hidden_state'] = last_hidden_state
+        return transformer_outputs
+
+
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        self.transformer = DistributedBloomModel(config)
+        if config.num_prefix_tokens > 0:
+            self.transformer = DistributedBloomPrefix(config)
+        else:
+            self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -70,3 +141,17 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.word_embeddings.weight = new_embeddings.weight
+
+
+class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        super().__init__(config)
+        if config.num_prefix_tokens > 0:
+            self.transformer = DistributedBloomPrefix(config)
+        else:
+            self.transformer = DistributedBloomModel(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()

Some files were not shown because too many files changed in this diff