Explorar el Código

Don't cast logits to float32 on GPU

Aleksandr Borzunov hace 2 años
padre
commit
86d08bf515
Se han modificado 1 ficheros con 1 adiciones y 1 borrados
  1. 1 1
      src/bloom/model.py

+ 1 - 1
src/bloom/model.py

@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):