
Don't cast logits to float32 on GPU

Aleksandr Borzunov, 2 years ago
parent
commit 86d08bf515
1 changed file with 1 addition and 1 deletion

src/bloom/model.py (+1, -1)

@@ -449,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
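
For context, below is a minimal, self-contained sketch of the forward path this commit changes. The `TiedLMHead` class, vocabulary/hidden sizes, and usage snippet are hypothetical stand-ins for the real `LMHead` in src/bloom/model.py (which also has a chunked path); the point is only to show that, without the trailing `.float()`, the logits stay in the word-embedding dtype instead of being upcast to a float32 tensor of shape [batch, seq_len, vocab_size] on the GPU.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TiedLMHead(nn.Module):
        # Simplified, hypothetical stand-in for LMHead: projects hidden
        # states onto the tied word-embedding matrix to produce logits.
        def __init__(self, word_embeddings: nn.Embedding):
            super().__init__()
            self.word_embeddings = word_embeddings

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            weight = self.word_embeddings.weight
            # Switch dtype in case word_embeddings are fp16/bf16.
            hidden_states = hidden_states.to(weight.dtype)
            # After this commit the logits keep the embedding dtype; the old
            # `.float()` call would have allocated an extra float32 copy of a
            # [batch, seq_len, vocab_size] tensor on the GPU.
            return F.linear(hidden_states, weight)

    # Usage sketch (sizes are illustrative, not taken from the repository):
    embeddings = nn.Embedding(250_880, 1024, dtype=torch.bfloat16)
    head = TiedLMHead(embeddings)
    logits = head(torch.randn(1, 8, 1024, dtype=torch.bfloat16))
    assert logits.dtype == torch.bfloat16  # no float32 upcast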