Browse Source

Keep prompts in float32, cast where necessary

Aleksandr Borzunov 2 năm trước cách đây
mục cha
commit
e520c4781e
1 tập tin đã thay đổi với 11 bổ sung3 xóa
  1. 11 3
      src/petals/client/remote_model.py

+ 11 - 3
src/petals/client/remote_model.py

@@ -95,12 +95,18 @@ class DistributedBloomModel(BloomModel):
             self.prefix_tokens = torch.arange(self.pre_seq_len).long()
 
             with force_non_empty_weights():
-                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
+                    logger.info(
+                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
+                        "to increase ptune quality"
+                    )
+                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                 if config.tuning_mode == "deep_ptune":
                     self.intermediate_prompt_embeddings = nn.Embedding(
                         self.pre_seq_len,
-                        config.num_hidden_layers * config.hidden_size
+                        config.num_hidden_layers * config.hidden_size,
                         # ^-- TODO: should be num_hidden_layers - 1
+                        dtype=torch.float32,
                     )
         elif config.tuning_mode:
             raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
@@ -122,7 +128,9 @@ class DistributedBloomModel(BloomModel):
             intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
         else:
             intermediate_prompts = DUMMY
-        return prompts, intermediate_prompts
+
+        dtype = self.word_embeddings.weight.dtype
+        return prompts.to(dtype), intermediate_prompts.to(dtype)
 
     def forward(
         self,