
refactoring

dbaranchuk 3 years ago
commit 79280c4371
2 changed files with 41 additions and 9 deletions
  1. src/bloom/model.py (+33 -7)
  2. src/client/remote_model.py (+8 -2)

src/bloom/model.py (+33 -7)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -319,16 +319,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
 
     def get_output_embeddings(self):
-        return self.lm_head.word_embeddings
+        return self.lm_head
 
     def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings.weight = new_embeddings.weight
+        self.lm_head = new_embeddings
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
@@ -361,7 +361,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_type=CausalLMOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -369,8 +382,21 @@ class BloomForCausalLM(BloomPreTrainedModel):
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
         hidden_states = transformer_outputs[0]
+
         lm_logits = self.lm_head(hidden_states)
 
         loss = None
@@ -415,7 +441,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class LMHeadForCausalLM(nn.Module):
+class LMHead(nn.Module):
     def __init__(self, config, word_embeddings: nn.Embedding):
         super().__init__()
         self.word_embeddings = word_embeddings

src/client/remote_model.py (+8 -2)

@@ -5,7 +5,7 @@ from typing import Optional, Tuple
 import hivemind
 from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHeadForCausalLM
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -54,6 +54,12 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
+        self.lm_head = LMHead(config, self.transformer.word_embeddings)
         # Initialize weights and apply final processing
         self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.word_embeddings
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.word_embeddings.weight = new_embeddings.weight
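
These two overrides restore the accessors that the base class just changed: because the distributed model keeps the embedding-tied LMHead rather than an nn.Linear, get_output_embeddings() has to return the shared nn.Embedding itself, and set_output_embeddings() has to rebind its weight instead of swapping out the module. A rough stand-in illustration of the tying contract these overrides preserve (the real DistributedBloomModel and its hivemind setup are elided here):

import torch.nn as nn

# Stand-in for the tied head: it only holds a reference to the shared embedding.
class TiedHeadStub(nn.Module):
    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings

shared = nn.Embedding(1024, 64)   # plays the role of transformer.word_embeddings
lm_head = TiedHeadStub(shared)

# get_output_embeddings(): hand back the shared embedding module itself.
assert lm_head.word_embeddings is shared

# set_output_embeddings(): rebind the weight in place, so the transformer's
# input embeddings and the LM head keep pointing at the same parameter.
new_embeddings = nn.Embedding(1024, 64)
lm_head.word_embeddings.weight = new_embeddings.weight
assert shared.weight is new_embeddings.weight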