Bladeren bron

Don't cast logits to float32 on GPU

Aleksandr Borzunov 2 jaren geleden
bovenliggende
commit
86d08bf515
1 gewijzigde bestanden met toevoegingen van 1 en 1 verwijderingen
  1. 1 1
      src/bloom/model.py

+ 1 - 1
src/bloom/model.py

@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):