peft_utils.py 512 B

123456789101112131415161718
  1. """
  2. Generalized parameter-efficient finetuning modules that support deep prompts and several types of adapters.
  3. Designed to be used on both the client and the server side.
  4. """
  5. import torch.nn as nn
  6. from src.utils.misc import DUMMY
  7. class GenericPEFTModule(nn.Module):
  8. """Container for PEFT parameters for a single transformer block, supports multiple modes"""
  9. def __init__(self, hidden_size: int):
  10. super().__init__()
  11. self.hidden_size = hidden_size
  12. self.prompts = nn.Parameter(DUMMY)