# model.py
  1. from typing import Optional
  2. import hivemind
  3. import torch
  4. import torch.nn as nn
  5. from hivemind.utils.logging import get_logger
  6. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
  7. from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
  8. from petals.client.from_pretrained import FromPretrainedMixin
  9. from petals.client.lm_head import LMHead
  10. from petals.client.ptune import PTuneMixin
  11. from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
  12. from petals.client.remote_sequential import RemoteSequential
  13. from petals.models.bloom.config import DistributedBloomConfig
# Module-level logger obtained via hivemind's logging helper.
logger = get_logger(__name__)
  15. class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
  16. """BloomModel, but all transformer layers are hosted by the swarm"""
  17. _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
  18. _keys_to_ignore_on_load_unexpected = [r"^h\."]
  19. config_class = DistributedBloomConfig
  20. def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
  21. n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
  22. super().__init__(config)
  23. assert len(self.h) == 0
  24. config.num_hidden_layers = n_layer
  25. self.h = RemoteSequential(config, dht=dht)
  26. self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
  27. self.init_prompts(config)
  28. def forward(
  29. self,
  30. input_ids: Optional[torch.LongTensor] = None,
  31. past_key_values: Optional[RemotePastKeyValues] = None,
  32. attention_mask: Optional[torch.Tensor] = None,
  33. head_mask: Optional[torch.LongTensor] = None,
  34. inputs_embeds: Optional[torch.LongTensor] = None,
  35. use_cache: Optional[bool] = None,
  36. output_attentions: Optional[bool] = None,
  37. output_hidden_states: Optional[bool] = None,
  38. return_dict: Optional[bool] = None,
  39. ):
  40. if input_ids is not None and inputs_embeds is not None:
  41. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  42. elif input_ids is not None:
  43. input_shape = input_ids.size()
  44. input_ids = input_ids.view(-1, input_shape[-1])
  45. elif inputs_embeds is not None:
  46. input_shape = inputs_embeds.size()[:-1]
  47. else:
  48. raise ValueError("You have to specify either input_ids or inputs_embeds")
  49. # The causal mask will be added on the server-side
  50. assert (
  51. attention_mask is None or (attention_mask == 1).all()
  52. ), f"Custom attention masks are not supported, {attention_mask=}"
  53. assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
  54. assert use_cache is None or use_cache, f"{use_cache=} is not supported"
  55. assert not output_attentions, f"{output_attentions=} is not supported"
  56. assert not output_hidden_states, f"{output_hidden_states=} is not supported"
  57. assert return_dict is None or return_dict, f"{return_dict=} is not supported"
  58. if inputs_embeds is None:
  59. inputs_embeds = self.word_embeddings(input_ids)
  60. if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
  61. batch_size = inputs_embeds.shape[0]
  62. prompts, intermediate_prompts = self.get_prompt(batch_size)
  63. inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
  64. else:
  65. prompts = intermediate_prompts = None
  66. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  67. output_shape = input_shape + (hidden_states.size(-1),)
  68. hidden_states = self.h(
  69. hidden_states,
  70. prompts=intermediate_prompts,
  71. hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
  72. )
  73. # Remove prefix
  74. if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
  75. hidden_states = hidden_states[:, self.pre_seq_len :]
  76. # Add last hidden state
  77. hidden_states = self.ln_f(hidden_states)
  78. hidden_states = hidden_states.view(output_shape)
  79. return BaseModelOutputWithPastAndCrossAttentions(
  80. last_hidden_state=hidden_states,
  81. past_key_values=RemotePastKeyValues(),
  82. hidden_states=None,
  83. attentions=None,
  84. )
  85. class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
  86. _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
  87. _keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
  88. _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
  89. config_class = DistributedBloomConfig
  90. def __init__(self, config: DistributedBloomConfig):
  91. BloomPreTrainedModel.__init__(self, config)
  92. self.transformer = DistributedBloomModel(config)
  93. self.lm_head = LMHead(config)
  94. # Initialize weights and apply final processing
  95. self.post_init()
  96. def get_output_embeddings(self):
  97. return self.lm_head
  98. class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
  99. _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
  100. _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
  101. config_class = DistributedBloomConfig
  102. def __init__(self, config: DistributedBloomConfig):
  103. BloomPreTrainedModel.__init__(self, config)
  104. self.num_labels = config.num_labels
  105. self.transformer = DistributedBloomModel(config)
  106. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  107. # Initialize weights and apply final processing
  108. self.post_init()