|
@@ -131,7 +131,7 @@ class DistributedBloomModel(BloomModel):
|
|
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
|
- if "ptune" in self.config.tuning_mode:
|
|
|
+ if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
|
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
|
|
|
else:
|
|
|
hidden_states = self.h(hidden_states)
|