ptune.py
import dataclasses
from contextlib import contextmanager
from typing import Optional

import torch
import torch.nn as nn
from hivemind import get_logger
from transformers import PretrainedConfig

from petals.utils.misc import DUMMY

logger = get_logger(__name__)

@dataclasses.dataclass
class PTuneConfig:
    pre_seq_len: int = 0  # the number of tokens used for prompt tuning
    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
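

# Usage sketch (an illustration, not from this file): a model config class is expected to mix in
# PTuneConfig alongside transformers' PretrainedConfig, so prompt tuning would be enabled roughly as
#
#     config = SomeDistributedConfig.from_pretrained(model_name)  # hypothetical config class
#     config.tuning_mode = "deep_ptune"
#     config.pre_seq_len = 16
#
# after which PTuneMixin.init_prompts() below allocates 16 trainable prefix embeddings
# (plus per-layer intermediate prompt embeddings in the "deep_ptune" mode).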


class PTuneMixin:
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: PretrainedConfig) -> None:
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,  # TODO: should be num_hidden_layers - 1
                self.config.hidden_size,
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)
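

def _prepend_prompts_example(model: "PTuneMixin", inputs_embeds: torch.Tensor) -> torch.Tensor:
    """
    Illustrative sketch, not part of the original module: shows how the `prompts` tensor returned by
    get_prompt() can be concatenated in front of the token embeddings before running the backbone,
    which is how prompt tuning injects trainable prefix tokens. Assumes `model` mixes in PTuneMixin
    and owns `word_embeddings`; the function name is hypothetical.
    """
    batch_size = inputs_embeds.shape[0]
    prompts, _intermediate_prompts = model.get_prompt(batch_size)
    # Result shape: [batch_size, pre_seq_len + seq_len, hidden_size]
    return torch.cat([prompts, inputs_embeds], dim=1)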


_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows bypassing the accelerate.init_empty_weights() context manager
    (which forces all nn.Parameters to be PyTorch meta tensors) used when low_cpu_mem_usage=True.
    The transformers library is supposed to replace all meta tensors with empty tensors by itself,
    but this does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = _original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
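

def _force_non_empty_weights_example() -> None:
    """
    Illustrative sketch, not part of the original module: demonstrates the effect of
    force_non_empty_weights() when nested inside accelerate.init_empty_weights(), assuming
    `accelerate` is installed. Modules created under init_empty_weights() get meta tensors,
    while modules created inside the inner context get real, trainable weights.
    """
    from accelerate import init_empty_weights

    with init_empty_weights():  # patches nn.Module.register_parameter to move parameters to the meta device
        empty = nn.Embedding(4, 8)
        with force_non_empty_weights():  # temporarily restores the original register_parameter
            real = nn.Embedding(4, 8)

    assert empty.weight.is_meta
    assert not real.weight.is_meta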