@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits

     def chunked_forward(self, hidden_states):
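To illustrate what this hunk changes, the minimal sketch below shows an LM head that reuses the input embedding matrix as its output projection, casts the hidden states to the embeddings' dtype, and no longer up-casts the logits to fp32 with `.float()`. The `TiedLMHead` class and its parameter names are illustrative assumptions, not the actual class from the patch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLMHead(nn.Module):
    """Sketch of an LM head whose weights are tied to the word embeddings."""

    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = self.word_embeddings.weight
        # Switch dtype in case word_embeddings are fp16/bf16
        hidden_states = hidden_states.to(weight.dtype)
        # No trailing .float(): logits stay in the embeddings' dtype,
        # avoiding an extra fp32 copy of the [batch, seq, vocab] tensor.
        return F.linear(hidden_states, weight)


# Usage: fp16 embeddings now yield fp16 logits.
emb = nn.Embedding(50257, 768, dtype=torch.float16)
head = TiedLMHead(emb)
logits = head(torch.randn(1, 8, 768))
print(logits.dtype)  # torch.float16
```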