dbaranchuk 3 years ago
parent
commit 6bffeff0a1
2 files changed, 13 insertions and 10 deletions
  1. +12 −9 src/bloom/model.py
  2. +1 −1 src/client/remote_model.py

src/bloom/model.py (+12 −9)

@@ -319,7 +319,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config)
+        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -328,7 +328,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         return self.lm_head.word_embeddings
 
     def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings = new_embeddings.weight
+        self.lm_head.word_embeddings.weight = new_embeddings.weight
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
@@ -418,18 +418,19 @@ class BloomForCausalLM(BloomPreTrainedModel):
 class LMHeadForCausalLM(nn.Module):
     def __init__(self, config, word_embeddings: nn.Embedding):
         super().__init__()
-        self.word_embeddings = word_embeddings.weight
+        self.word_embeddings = word_embeddings
         self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
 
     def forward(self, hidden_states):
-        if self.word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
-            'cpu' in self.word_embeddings.device:
+        word_embeddings = self.word_embeddings.weight
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
+            word_embeddings.device.type == 'cpu':
             # We use 'chunked_forward' only for half-precision computations on CPU.
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(self.word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, self.word_embeddings).float()
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings).float()
         return lm_logits
 
     def chunked_forward(self, hidden_states):
@@ -438,11 +439,13 @@ class LMHeadForCausalLM(nn.Module):
         """
         assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
 
+        word_embeddings = self.word_embeddings.weight
+
         hidden_states = hidden_states.float()
-        num_embeddings = self.word_embeddings.shape[1]
+        num_embeddings = word_embeddings.shape[0]
         output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
 
         for i in range(0, num_embeddings, self.chunk_size):
-            chunk = self.word_embeddings[..., i:i+self.chunk_size].float()
+            chunk = word_embeddings[i:i+self.chunk_size].float()
             output[..., i:i+self.chunk_size] = F.linear(hidden_states, chunk)
         return output
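
Why this hunk matters: the old check `'cpu' in self.word_embeddings.device` fails at runtime because torch.device is not a string, so the fix compares `device.type` instead; and `chunked_forward` previously sliced the embedding matrix along dim 1, but an `nn.Embedding` weight has shape `[vocab_size, hidden_size]`, so vocabulary chunks are rows (dim 0). Below is a minimal standalone sketch of the corrected behavior, with hypothetical class and parameter names rather than the repo's exact code:

import torch
import torch.nn.functional as F
from torch import nn

class ChunkedLMHead(nn.Module):
    """Sketch of the fixed LM head; chunk_size default is hypothetical."""

    def __init__(self, word_embeddings: nn.Embedding, chunk_size: int = 4096):
        super().__init__()
        # Store the module itself, not its .weight tensor, so later weight
        # swaps on the shared embedding stay visible to the head.
        self.word_embeddings = word_embeddings
        self.chunk_size = chunk_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = self.word_embeddings.weight
        # torch.device is not a string: compare its .type, as the commit does.
        if weight.dtype in (torch.float16, torch.bfloat16) and weight.device.type == "cpu":
            return self.chunked_forward(hidden_states)
        hidden_states = hidden_states.to(weight.dtype)
        return F.linear(hidden_states, weight).float()

    def chunked_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = self.word_embeddings.weight
        hidden_states = hidden_states.float()
        vocab_size = weight.shape[0]  # rows of the [vocab_size, hidden_size] matrix
        output = torch.zeros(*hidden_states.shape[:-1], vocab_size)
        for i in range(0, vocab_size, self.chunk_size):
            # Upcast only one row-chunk at a time to bound fp32 peak memory.
            chunk = weight[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output

# Quick check: bf16 embeddings on CPU take the chunked path and produce
# logits over the full vocabulary.
emb = nn.Embedding(1000, 64).to(torch.bfloat16)
head = ChunkedLMHead(emb, chunk_size=256)
logits = head(torch.randn(2, 3, 64))
assert logits.shape == (2, 3, 1000)

The design choice behind chunking: only chunk_size rows of the vocabulary matrix are upcast to fp32 at any moment, which is what makes half-precision inference on CPU fit in memory; the single fused F.linear path is used everywhere else.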

src/client/remote_model.py (+1 −1)

@@ -54,6 +54,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config)
+        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
         # Initialize weights and apply final processing
         self.post_init()
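
Both constructors now pass the transformer's embedding module into the head, and the head keeps the module itself rather than a bare weight tensor. A toy illustration (hypothetical names, not the repo's API) of why that distinction makes the fixed set_output_embeddings keep both ends tied:

import torch
from torch import nn

emb = nn.Embedding(10, 4)  # stands in for transformer.word_embeddings

class TiedHead(nn.Module):
    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings  # shared module, as in this commit

head = TiedHead(emb)
new_emb = nn.Embedding(10, 4)

# Mirrors the fixed set_output_embeddings: assigning through the shared
# module replaces the Parameter that both the transformer and the head read.
head.word_embeddings.weight = new_emb.weight
assert emb.weight is new_emb.weight

# The old code stored the bare tensor (word_embeddings.weight); rebinding
# that attribute would have updated the head only, silently untying it
# from the transformer's input embeddings.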