
Remove excess .float() cast

Aleksandr Borzunov · 2 years ago · parent commit d8dac556a6
1 changed file with 1 addition and 1 deletion:
    src/client/remote_model.py

src/client/remote_model.py (+1 −1)

@@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
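
For context, here is a minimal standalone sketch of the affected pattern. It is not the repository's code: the tensor sizes are hypothetical, and only the identifiers word_embeddings_layernorm and inputs_embeds come from the diff above. The cast is presumably redundant because inputs_embeds already arrives in the dtype the layer norm expects, so the extra .float() only produced a needless fp32 copy.

    # Hypothetical sketch: applying a LayerNorm to embeddings that are
    # already float32, so no explicit .float() cast is needed.
    import torch
    import torch.nn as nn

    hidden_size = 64                                 # hypothetical size
    word_embeddings_layernorm = nn.LayerNorm(hidden_size)

    inputs_embeds = torch.randn(2, 8, hidden_size)   # already float32 here
    hidden_states = word_embeddings_layernorm(inputs_embeds)  # no .float() cast
    print(hidden_states.shape)                       # torch.Size([2, 8, 64])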