dbaranchuk, 3 years ago
parent commit 5168a3405a
1 changed file with 4 additions and 3 deletions

src/bloom/model.py  (+4, -3)

@@ -436,8 +436,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. It reduces initial memory consumption which might be crucial for large dictionaries. In addition, it provides
-    an effcient way to perform half-precision calculations on CPU.  
+    embeddings. Thus, it reduces initial memory consumption, which might be crucial for large dictionaries.
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
@@ -449,9 +449,10 @@ class LMHead(nn.Module):
 
 
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
+        
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
         if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
             word_embeddings.device.type == 'cpu':
-            # We use 'chunked_forward' only for half-precision computations on CPU.
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
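For context, here is a minimal sketch of the idea this commit documents: the LM head reuses the shared input-embedding matrix directly (so no extra tensor is allocated for a tied linear layer) and switches to a chunked float32 matmul when the embeddings are kept in fp16/bf16 on CPU. The class name, the `chunk_size` parameter, and the exact chunking strategy below are assumptions for illustration, not the repository's actual code.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLMHeadSketch(nn.Module):
    """Sketch of an LM head that reuses the input embedding weights (no extra tied tensor)."""

    def __init__(self, word_embeddings: nn.Embedding, chunk_size: int = 8192):
        super().__init__()
        self.word_embeddings = word_embeddings  # shared with the model's input embeddings
        self.chunk_size = chunk_size  # hypothetical vocabulary chunk size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = self.word_embeddings.weight
        # Half-precision matmuls are slow or unsupported on CPU, so fall back to
        # processing the vocabulary in float32 chunks instead of upcasting everything.
        if weight.dtype in (torch.float16, torch.bfloat16) and weight.device.type == "cpu":
            return self.chunked_forward(hidden_states)
        # Otherwise project directly with the shared weights (cast hidden states to match).
        return F.linear(hidden_states.to(weight.dtype), weight)

    def chunked_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = self.word_embeddings.weight
        hidden_states = hidden_states.float()
        logits = []
        for start in range(0, weight.shape[0], self.chunk_size):
            # Upcast one vocabulary slice at a time to keep peak memory low.
            chunk = weight[start : start + self.chunk_size].float()
            logits.append(F.linear(hidden_states, chunk))
        return torch.cat(logits, dim=-1)


# Toy usage (sizes are made up for illustration):
embeddings = nn.Embedding(50_000, 64).to(torch.bfloat16)
head = TiedLMHeadSketch(embeddings, chunk_size=8_192)
print(head(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 50000])

The actual chunked_forward in this file may differ (chunk size, buffering, dtype handling); the sketch only illustrates why the branch is gated on fp16/bf16 weights that live on CPU, as the relocated comment in the diff states.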