
Don't cast logits to float32 on GPU

Aleksandr Borzunov, 2 years ago
commit 86d08bf515
1 changed file with 1 addition and 1 deletion

+ 1 - 1
src/bloom/model.py

@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
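
A minimal sketch (illustrative shapes and standalone code, not part of the commit) of what the dropped .float() cast saves: with BLOOM's 250880-token vocabulary, upcasting the logits to float32 doubles their memory footprint compared to keeping them in the fp16/bf16 word-embedding dtype.

import torch
import torch.nn.functional as F

# Illustrative sizes; vocab matches BLOOM's 250880-token vocabulary.
batch, seq_len, hidden, vocab = 1, 128, 1024, 250880
dtype = torch.bfloat16  # stands in for the fp16/bf16 word_embeddings dtype

hidden_states = torch.randn(batch, seq_len, hidden, dtype=dtype)
word_embeddings = torch.randn(vocab, hidden, dtype=dtype)

logits_new = F.linear(hidden_states, word_embeddings)          # after: stays bf16
logits_old = F.linear(hidden_states, word_embeddings).float()  # before: upcast to fp32

def mib(t):
    # Tensor size in MiB: element count times bytes per element.
    return t.numel() * t.element_size() / 2**20

print(f"bf16 logits: {mib(logits_new):.1f} MiB")  # ~61 MiB
print(f"fp32 logits: {mib(logits_old):.1f} MiB")  # ~122 MiB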