5
0
dbaranchuk 3 жил өмнө
parent
commit
5168a3405a
1 өөрчлөгдсөн 4 нэмэгдсэн , 3 устгасан
  1. 4 3
      src/bloom/model.py

+ 4 - 3
src/bloom/model.py

@@ -436,8 +436,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. It reduces initial memory consumption which might be crucial for large dictionaries. In addition, it provides
-    an effcient way to perform half-precision calculations on CPU.  
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.  
     """,
     BLOOM_START_DOCSTRING,
 )
@@ -449,9 +449,10 @@ class LMHead(nn.Module):
 
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
+        
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
         if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
             word_embeddings.device.type == 'cpu':
-            # We use 'chunked_forward' only for half-precision computations on CPU.
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16