Browse Source

LM head for CausalLM & chunked forward

dbaranchuk 3 years ago
parent
commit
df42822f26
2 changed files with 48 additions and 7 deletions
  1. +47 -7
      src/bloom/model.py
  2. +1 -0
      src/client/remote_model.py

+ 47 - 7
src/bloom/model.py

@@ -319,14 +319,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
+        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
+
         # Initialize weights and apply final processing
         self.post_init()
 
     def get_output_embeddings(self):
-        return self.transformer.word_embeddings
+        return self.lm_head.word_embeddings
 
     def set_output_embeddings(self, new_embeddings):
-        self.transformer.word_embeddings.weight = new_embeddings.weight
+        self.lm_head.word_embeddings = new_embeddings.weight
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        # only last token for input_ids if past is defined in kwargs
@@ -368,11 +370,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
-        word_embeddings = self.transformer.word_embeddings.weight
-
-        # Switch dtype in case word_embeddings are fp16/bf16
-        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
-        lm_logits = F.linear(hidden_states, word_embeddings).float()
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
@@ -406,3 +405,44 @@ class BloomForCausalLM(BloomPreTrainedModel):
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past
         )
+
+
+@add_start_docstrings(
+    """
+    The modified language modeling head that does not create an extra tensor for the linear layer: its weights are tied to
+    the input embeddings. This reduces initial memory consumption, which can be crucial for large vocabularies. In addition,
+    it provides an efficient way to perform half-precision computations on CPU.
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class LMHeadForCausalLM(nn.Module):
+    def __init__(self, config, word_embeddings: nn.Embedding):
+        super().__init__()
+        self.word_embeddings = word_embeddings.weight
+        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+    def forward(self, hidden_states):
+        if self.word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
+            self.word_embeddings.device.type == 'cpu':
+            # We use 'chunked_forward' only for half-precision computations on CPU.
+            lm_logits = self.chunked_forward(hidden_states)
+        else:
+            # Switch dtype in case word_embeddings are fp16/bf16
+            hidden_states = hidden_states.to(self.word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, self.word_embeddings).float()
+        return lm_logits
+
+    def chunked_forward(self, hidden_states):
+        """ Splits the word embeddings into chunks along the vocabulary dimension and casts them to fp32
+            one at a time to perform the matmul more efficiently on CPU.
+            chunk_size: trades off speed against extra memory consumption.
+        """
+        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+
+        hidden_states = hidden_states.float()
+        num_embeddings = self.word_embeddings.shape[0]  # vocabulary size
+        output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
+
+        for i in range(0, num_embeddings, self.chunk_size):
+            chunk = self.word_embeddings[i : i + self.chunk_size].float()
+            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output
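
To make the chunked trick concrete, here is a minimal standalone sketch (toy sizes and hypothetical names, not part of this commit) checking that chunk-wise upcasting of an fp16 weight over the vocabulary dimension reproduces the full fp32 matmul while only one fp32 chunk is materialized at a time:

# Toy demonstration of chunked fp16 -> fp32 matmul on CPU (hypothetical sizes)
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_size, chunk_size = 50, 16, 20
weight = torch.randn(vocab_size, hidden_size).half()  # tied embedding weight, fp16 on CPU
hidden_states = torch.randn(2, 3, hidden_size)        # fp32 activations

# Chunked path: upcast chunk_size rows at a time instead of the whole matrix
output = torch.zeros(*hidden_states.shape[:-1], vocab_size)
for i in range(0, vocab_size, chunk_size):
    output[..., i : i + chunk_size] = F.linear(hidden_states, weight[i : i + chunk_size].float())

# Reference path: upcast the entire weight at once
reference = F.linear(hidden_states, weight.float())
assert torch.allclose(output, reference, atol=1e-4)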

+ 1 - 0
src/client/remote_model.py

@@ -23,6 +23,7 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for the LM head, for efficient half-precision on CPU
 
 
 class DistributedBloomModel(BloomModel):
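
For a sense of scale, a back-of-the-envelope sketch of the trade-off the default chunk size implies (assuming BLOOM-176B dimensions, vocab_size = 250880 and hidden_size = 14336, which are not stated in this diff):

# Rough memory math for the LM head upcast (BLOOM-176B sizes assumed, not from this diff)
vocab_size, hidden_size, chunk_size = 250_880, 14_336, 10_000
full_upcast_gib = vocab_size * hidden_size * 4 / 2**30  # fp32 copy of the whole tied weight
per_chunk_mib = chunk_size * hidden_size * 4 / 2**20    # transient fp32 chunk
print(f"full upcast: {full_upcast_gib:.1f} GiB, per chunk: {per_chunk_mib:.0f} MiB")
# -> full upcast: 13.4 GiB, per chunk: 547 MiB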