|
@@ -436,8 +436,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
|
|
@add_start_docstrings(
|
|
|
"""
|
|
|
The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
|
|
|
- embeddings. It reduces initial memory consumption which might be crucial for large dictionaries. In addition, it provides
|
|
|
- an effcient way to perform half-precision calculations on CPU.
|
|
|
+ embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
|
|
|
+ In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
|
|
|
""",
|
|
|
BLOOM_START_DOCSTRING,
|
|
|
)
|
|
@@ -449,9 +449,10 @@ class LMHead(nn.Module):
|
|
|
|
|
|
def forward(self, hidden_states):
|
|
|
word_embeddings = self.word_embeddings.weight
|
|
|
+
|
|
|
+ # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
|
|
|
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
|
|
|
word_embeddings.device.type == 'cpu':
|
|
|
- # We use 'chunked_forward' only for half-precision computations on CPU.
|
|
|
lm_logits = self.chunked_forward(hidden_states)
|
|
|
else:
|
|
|
# Switch dtype in case word_embeddings are fp16/bf16
|