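"""Prompt tuning utilities: trainable input prompts ("ptune") and per-layer deep prompts ("deep_ptune")."""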

import dataclasses
from contextlib import contextmanager
from typing import Optional

import torch
import torch.nn as nn
from hivemind import get_logger
from transformers import PretrainedConfig

from petals.utils.misc import DUMMY

logger = get_logger(__name__)


@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0  # number of prefix tokens used for prompt tuning
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
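
# "ptune" trains only the input-level prompt embeddings; "deep_ptune" additionally
# learns a separate prompt for every transformer block (hence the
# num_hidden_layers * hidden_size embedding width below).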


class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)
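        # prompts: (batch_size, pre_seq_len, hidden_size)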

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size,
                # TODO: should be num_hidden_layers - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
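            # after permute: (num_hidden_layers, batch_size, pre_seq_len, hidden_size)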
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)


_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows bypassing the accelerate.init_empty_weights() context manager
    (which forces all nn.Parameters to be PyTorch meta tensors) used when low_cpu_mem_usage=True.
    The transformers library is supposed to replace all meta tensors with empty tensors on its own,
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """
    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = _original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
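

# ---------------------------------------------------------------------------
# Illustrative usage sketch. `_ToyConfig` and `_ToyModel` are hypothetical
# stand-ins (not part of Petals) showing how a model class is expected to
# combine PTuneMixin with its own word embeddings; the duck-typed `_ToyConfig`
# stands in for a transformers PretrainedConfig.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    @dataclasses.dataclass
    class _ToyConfig(PTuneConfig):
        hidden_size: int = 16
        num_hidden_layers: int = 4
        vocab_size: int = 100

    class _ToyModel(PTuneMixin, nn.Module):
        def __init__(self, config: _ToyConfig):
            super().__init__()
            self.config = config
            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
            self.init_prompts(config)

    model = _ToyModel(_ToyConfig(pre_seq_len=8, tuning_mode="deep_ptune"))
    prompts, intermediate_prompts = model.get_prompt(batch_size=2)
    print(prompts.shape)  # torch.Size([2, 8, 16])
    print(intermediate_prompts.shape)  # torch.Size([4, 2, 8, 16])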