remote_model.py

# this code is in active development, interfaces may change
from typing import List, Optional

import hivemind
import torch
import torch.nn as nn
from hivemind import get_logger, use_hivemind_log_handler
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

from src.bloom.model import (
    BloomConfig,
    BloomForCausalLM,
    BloomForSequenceClassification,
    BloomModel,
    BloomPreTrainedModel,
    LMHead,
)
from src.client.remote_generation import RemoteGenerationMixin
from src.client.remote_sequential import RemoteSequential
from src.constants import PUBLIC_INITIAL_PEERS
from src.utils.misc import DUMMY

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class DistributedBloomConfig(BloomConfig):
    """
    A bloom config that contains information about DHT peers.
    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
    """

    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
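
# A configuration sketch (illustrative only, not part of the module): connecting through
# the public bootstrap peers vs. reusing an existing DHT instance. The dht_prefix string
# "bigscience/bloom" is a placeholder.
#
#     config = DistributedBloomConfig(dht_prefix="bigscience/bloom")  # uses PUBLIC_INITIAL_PEERS
#
#     # or, to share one DHT client between several models:
#     dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)
#     config = DistributedBloomConfig(dht_prefix="bigscience/bloom", dht=dht)
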
class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.n_layer = n_layer

        dht = (
            config.dht
            if config.dht is not None
            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
        )
        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
        self.h = RemoteSequential(config, dht, config.dht_prefix)

        # Forbid accumulate grads for embeddings and layernorm
        self.set_requires_grad(False)

        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            if config.tuning_mode == "deep_ptune":
                self.intermediate_prompt_embeddings = nn.Embedding(
                    self.pre_seq_len,
                    config.num_hidden_layers * config.hidden_size
                    # ^-- TODO: should be num_hidden_layers - 1
                )
                self.intermediate_prompt_embeddings.weight.data.zero_()
        elif config.tuning_mode:
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
            )
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        return prompts, intermediate_prompts
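
    # Shape sketch for get_prompt (illustrative, assuming batch_size=B, pre_seq_len=P,
    # hidden_size=H, and L = len(self.h) remote blocks):
    #   prompts:              [B, P, H]     -- prepended to the input embeddings in forward()
    #   intermediate_prompts: [L, B, P, H]  -- one extra prompt per remote block in "deep_ptune"
    #                                          mode, or the DUMMY placeholder otherwise
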
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
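
# A minimal forward-pass sketch (illustrative, not part of the module): the checkpoint
# name below is a placeholder, and the swarm referenced by dht_prefix must actually be
# serving the transformer blocks for the call to succeed.
#
#     config = DistributedBloomConfig.from_pretrained("bigscience/bloom")  # placeholder name
#     config.dht_prefix = "bigscience/bloom"
#     model = DistributedBloomModel(config)
#     input_ids = torch.tensor([[1, 2, 3]])         # toy token ids
#     outputs = model(input_ids=input_ids)          # attention_mask is not supported yet
#     hidden = outputs.last_hidden_state            # [batch, seq_len, hidden_size]
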
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.word_embeddings

    def get_output_embeddings(self):
        if self.config.tie_word_embeddings:
            return None
        return self.lm_head

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        assert isinstance(new_embeddings, nn.Embedding)
        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings

    def set_output_embeddings(self, new_lm_head: nn.Linear):
        with torch.no_grad():
            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
            self.lm_head.bias[...] = new_lm_head.bias
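
# A minimal generation sketch (illustrative): the tokenizer/checkpoint names are
# placeholders, and .generate() is assumed to be provided by RemoteGenerationMixin.
#
#     from transformers import BloomTokenizerFast
#     tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
#     config = DistributedBloomConfig.from_pretrained("bigscience/bloom")
#     config.dht_prefix = "bigscience/bloom"
#     model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom", config=config)
#     inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
#     outputs = model.generate(inputs, max_new_tokens=8)
#     print(tokenizer.decode(outputs[0]))
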
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    """BloomForSequenceClassification, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
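
# A fine-tuning sketch (illustrative): DistributedBloomModel freezes its local parameters
# in __init__, so an optimizer typically sees only the classification head and, when a
# "ptune" tuning_mode is set, the prompt embeddings. The learning rate is a placeholder.
#
#     model = DistributedBloomForSequenceClassification(config)  # config as in the sketches above
#     trainable_params = [p for p in model.parameters() if p.requires_grad]
#     optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)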