|
@@ -40,10 +40,6 @@ class PTuneMixin:
|
|
|
elif config.tuning_mode:
|
|
|
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
|
|
|
|
|
- def set_requires_grad(self, value):
|
|
|
- for p in self.parameters():
|
|
|
- p.requires_grad = value
|
|
|
-
|
|
|
def get_prompt(self, batch_size):
|
|
|
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
|
|
|
prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
|