
fix comments

dbaranchuk, 3 years ago
parent
commit 5168a3405a
1 changed file with 4 additions and 3 deletions

src/bloom/model.py (+4 −3)

@@ -436,8 +436,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. It reduces initial memory consumption which might be crucial for large dictionaries. In addition, it provides
-    an effcient way to perform half-precision calculations on CPU.  
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
+    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.  
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
@@ -449,9 +449,10 @@ class LMHead(nn.Module):
 
 
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
+        
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
         if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
             word_embeddings.device.type == 'cpu':
-            # We use 'chunked_forward' only for half-precision computations on CPU.
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
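
For context on the pattern this commit documents, here is a minimal, self-contained sketch (not the actual src/bloom/model.py code) of an LM head that reuses the input word embeddings as its output projection and falls back to a chunked matmul when those embeddings are fp16/bf16 on CPU. The class name TiedLMHeadSketch and the chunk_size parameter are illustrative assumptions, not part of the real API.

    # Sketch only: illustrates tied-weight projection plus a chunked CPU fallback.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TiedLMHeadSketch(nn.Module):
        def __init__(self, word_embeddings: nn.Embedding, chunk_size: int = 256):
            super().__init__()
            # Shared with the input embedding layer, so no extra weight tensor is created.
            self.word_embeddings = word_embeddings
            self.chunk_size = chunk_size  # hypothetical knob for the chunked path

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            word_embeddings = self.word_embeddings.weight

            # Chunked path only when embeddings are half-precision on CPU,
            # where fp16/bf16 matmuls are slow or unsupported.
            if word_embeddings.dtype in (torch.float16, torch.bfloat16) and \
                    word_embeddings.device.type == "cpu":
                return self.chunked_forward(hidden_states)

            # Otherwise match dtypes and project the whole vocabulary at once.
            hidden_states = hidden_states.to(word_embeddings.dtype)
            return F.linear(hidden_states, word_embeddings)

        def chunked_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            word_embeddings = self.word_embeddings.weight
            # Upcast a small slice of the vocabulary at a time, so a full fp32
            # copy of the embedding matrix is never materialized.
            chunks = []
            for start in range(0, word_embeddings.shape[0], self.chunk_size):
                chunk = word_embeddings[start:start + self.chunk_size].float()
                chunks.append(F.linear(hidden_states.float(), chunk))
            return torch.cat(chunks, dim=-1)

This keeps the initial memory footprint low for large vocabularies, which is what the updated docstring in the diff above refers to.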