
Merge pull request #19 from learning-at-home/lm_head

add a modified LM head
Dmitry Baranchuk, 3 years ago
parent commit ac7df18dfa
2 changed files with 97 additions and 20 deletions
  1. src/bloom/model.py (+79 -16)
  2. src/client/remote_model.py (+18 -4)

src/bloom/model.py (+79 -16)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -165,19 +165,12 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-        # Forbid accumulate grads for embeddings and layernorm
-        self.set_requires_grad(False)
-
     def get_input_embeddings(self):
         return self.word_embeddings
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
 
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -319,14 +312,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
         # Initialize weights and apply final processing
         self.post_init()
 
     def get_output_embeddings(self):
-        return self.transformer.word_embeddings
+        return self.lm_head
 
     def set_output_embeddings(self, new_embeddings):
-        self.transformer.word_embeddings.weight = new_embeddings.weight
+        self.lm_head = new_embeddings
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
@@ -359,7 +354,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_type=CausalLMOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -367,12 +375,22 @@ class BloomForCausalLM(BloomPreTrainedModel):
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
-        word_embeddings = self.transformer.word_embeddings.weight
 
-        # Switch dtype in case word_embeddings are fp16/bf16
-        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
-        lm_logits = F.linear(hidden_states, word_embeddings).float()
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        lm_logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
@@ -406,3 +424,48 @@ class BloomForCausalLM(BloomPreTrainedModel):
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past
         )
+
+
+@add_start_docstrings(
+    """
+    The modified language modeling head, which does not create an extra tensor for the linear layer whose weights are
+    tied to the input embeddings. This reduces initial memory consumption, which can be crucial for large vocabularies.
+    In addition, it provides an efficient way to handle half-precision word embeddings on CPU.
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class LMHead(nn.Module):
+    def __init__(self, config, word_embeddings: nn.Embedding):
+        super().__init__()
+        self.word_embeddings = word_embeddings
+        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+    def forward(self, hidden_states):
+        word_embeddings = self.word_embeddings.weight
+
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == 'cpu':
+            lm_logits = self.chunked_forward(hidden_states)
+        else:
+            # Switch dtype in case word_embeddings are fp16/bf16
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings).float()
+        return lm_logits
+
+    def chunked_forward(self, hidden_states):
+        """ Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. 
+            chunk_size: provides trade-off between efficiency and extra memory consumption. 
+        """
+        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+
+        word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings
+
+        hidden_states = hidden_states.float()
+        output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
+
+        for i in range(0, num_embeddings, self.chunk_size):
+            chunk = word_embeddings[i: i + self.chunk_size].float()
+            output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output
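
For intuition, below is a minimal standalone sketch of the chunked projection that LMHead.chunked_forward performs, written in plain PyTorch with illustrative sizes (not the real BLOOM dimensions); the helper name chunked_lm_head is hypothetical and not part of this repository:

import torch
import torch.nn.functional as F

def chunked_lm_head(hidden_states: torch.Tensor, word_embeddings: torch.Tensor, chunk_size: int = 10_000) -> torch.Tensor:
    """Project hidden states onto the vocabulary, upcasting one slice of the
    fp16/bf16 embedding matrix to fp32 at a time instead of all at once."""
    assert chunk_size > 0, "Chunk size for chunked forward must be positive"
    vocab_size = word_embeddings.shape[0]

    hidden_states = hidden_states.float()
    logits = torch.zeros(*hidden_states.shape[:-1], vocab_size)

    for i in range(0, vocab_size, chunk_size):
        chunk = word_embeddings[i : i + chunk_size].float()  # upcast a single slice
        logits[..., i : i + chunk_size] = F.linear(hidden_states, chunk)
    return logits

# Toy usage on CPU with bf16 embeddings, the case the head is designed for.
embeddings = torch.randn(50_000, 64, dtype=torch.bfloat16)  # [vocab_size, hidden_size]
hidden = torch.randn(2, 8, 64)                               # [batch, seq_len, hidden_size]
print(chunked_lm_head(hidden, embeddings).shape)             # torch.Size([2, 8, 50000])

With chunk_size = 10000, only a 10000 x hidden_size slice of the embedding matrix is materialized in fp32 at any moment; this is the trade-off that the chunk_size_for_efficient_fp16_on_cpu setting below controls.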

src/client/remote_model.py (+18 -4)

@@ -1,12 +1,11 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import hivemind
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -23,6 +22,7 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # chunk size used by the LM head for efficient half-precision inference on CPU
 
 
 class DistributedBloomModel(BloomModel):
@@ -45,6 +45,13 @@ class DistributedBloomModel(BloomModel):
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
         self.h = RemoteSequential(config, dht, config.dht_prefix)
+    
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad = value
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
@@ -54,5 +61,12 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHead(config, self.transformer.word_embeddings)
         # Initialize weights and apply final processing
         self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.word_embeddings
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.word_embeddings.weight = new_embeddings.weight
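
For reference, here is a minimal sketch of the weight sharing that DistributedBloomForCausalLM sets up here: the LM head keeps a reference to the input embedding module instead of allocating a separate vocab_size x hidden_size matrix. The TiedLMHead class and the sizes below are illustrative only, not the repository's API:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedLMHead(nn.Module):
    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings  # shared module, not a copy

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = self.word_embeddings.weight
        return F.linear(hidden_states.to(weight.dtype), weight).float()

embeddings = nn.Embedding(num_embeddings=1000, embedding_dim=32)
head = TiedLMHead(embeddings)

# No extra weight tensor is created: the head's only parameter is the embedding matrix itself.
assert head.word_embeddings.weight is embeddings.weight

hidden = torch.randn(2, 5, 32)
print(head(hidden).shape)  # torch.Size([2, 5, 1000])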