瀏覽代碼

revert changes

dbaranchuk 3 年之前
父節點
當前提交
682c6711df
共有 1 個文件被更改,包括 6 次插入12 次删除
  1. 6 12
      src/bloom/model.py

+ 6 - 12
src/bloom/model.py

@@ -237,6 +237,11 @@ class BloomModel(BloomPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
+        # Compute alibi tensor: check build_alibi_tensor documentation
+        current_sequence_length = hidden_states.shape[1]
+        if past_key_values and past_key_values[0]:
+            current_sequence_length += past_key_values[0][0].shape[1]
+
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
             if output_hidden_states:
@@ -507,22 +512,11 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-
-        config.pre_seq_len = 16
-        config.prompt_tuning_mode = 'deep'
-
-        if config.pre_seq_len > 0:
-            self.transformer = BloomPrefix(config)
-        else:
-            self.transformer = BloomModel(config)
-
-        self.pooled_dropout = nn.Dropout(0.0)#pooled_dropout)
+        self.transformer = BloomModel(config)
         self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
-        if hasattr(self.transformer, 'intermediate_prompt_embeddings'):
-            self.transformer.intermediate_prompt_embeddings.weight.data.zero_()
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(