remote_model.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # this code is in active development, interfaces may change
  2. from typing import List, Optional, Tuple
  3. import hivemind
  4. import torch
  5. import torch.nn as nn
  6. from hivemind import get_logger, use_hivemind_log_handler
  7. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
  8. from src.bloom.model import (
  9. BloomConfig,
  10. BloomForCausalLM,
  11. BloomForSequenceClassification,
  12. BloomModel,
  13. BloomPreTrainedModel,
  14. LMHead,
  15. )
  16. from src.client.remote_generation import RemoteGenerationMixin
  17. from src.client.remote_sequential import RemoteSequential
# Configure hivemind logging in "in_root_logger" mode
# (presumably: attach hivemind's log handler to the root logger — confirm in hivemind's logging utilities)
use_hivemind_log_handler("in_root_logger")
# Module-level logger; DistributedBloomModel.forward uses it to report ignored keyword arguments
logger = get_logger(__file__)
  20. class DistributedBloomConfig(BloomConfig):
  21. """
  22. A bloom config that contains information about DHT peers.
  23. To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
  24. """
  25. initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
  26. dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
  27. dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
  28. chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
  29. pre_seq_len: int = 0 # a number of tokens for prompt tuning.
  30. class DistributedBloomModel(BloomModel):
  31. """BloomModel, but all transformer layers are hosted by the swarm"""
  32. config_class = DistributedBloomConfig
  33. def __init__(self, config: DistributedBloomConfig):
  34. assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
  35. assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
  36. n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
  37. super().__init__(config)
  38. assert len(self.h) == 0
  39. config.n_layer = n_layer
  40. dht = (
  41. config.dht
  42. if config.dht is not None
  43. else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
  44. )
  45. assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
  46. self.h = RemoteSequential(config, dht, config.dht_prefix)
  47. # Forbid accumulate grads for embeddings and layernorm
  48. self.set_requires_grad(False)
  49. def set_requires_grad(self, value):
  50. for p in self.parameters():
  51. p.requires_grad = value
  52. def forward(
  53. self,
  54. input_ids: Optional[torch.LongTensor] = None,
  55. inputs_embeds: Optional[torch.Tensor] = None,
  56. attention_mask: Optional[torch.Tensor] = None,
  57. **kwargs,
  58. ):
  59. assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
  60. for k, v in kwargs.items():
  61. if not (v is None or v is False):
  62. logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
  63. if input_ids is not None and inputs_embeds is not None:
  64. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  65. elif input_ids is not None:
  66. input_shape = input_ids.size()
  67. input_ids = input_ids.view(-1, input_shape[-1])
  68. elif inputs_embeds is not None:
  69. input_shape = inputs_embeds.size()[:-1]
  70. else:
  71. raise ValueError("You have to specify either input_ids or inputs_embeds")
  72. if inputs_embeds is None:
  73. inputs_embeds = self.word_embeddings(input_ids)
  74. # Note: it supports only float32 or bfloat16 inputs
  75. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  76. output_shape = input_shape + (hidden_states.size(-1),)
  77. hidden_states = self.h(hidden_states)
  78. # Add last hidden state
  79. hidden_states = self.ln_f(hidden_states)
  80. hidden_states = hidden_states.view(output_shape)
  81. return BaseModelOutputWithPastAndCrossAttentions(
  82. last_hidden_state=hidden_states,
  83. past_key_values=None,
  84. hidden_states=None,
  85. attentions=None,
  86. )
  87. class DistributedBloomPrefix(DistributedBloomModel):
  88. """DistributedBloomModel with prefix tokens for prompt tuning"""
  89. def __init__(self, config):
  90. super().__init__(config)
  91. assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
  92. self.pre_seq_len = config.pre_seq_len
  93. self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
  94. self.prefix_tokens = torch.arange(self.pre_seq_len).long()
  95. def get_prompt(self, batch_size):
  96. prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
  97. prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
  98. prompts = self.prompt_embeddings(prefix_tokens)
  99. return prompts
  100. def forward(
  101. self,
  102. input_ids: Optional[torch.LongTensor] = None,
  103. inputs_embeds: Optional[torch.Tensor] = None,
  104. attention_mask: Optional[torch.Tensor] = None,
  105. **kwargs,
  106. ):
  107. assert (
  108. input_ids is None or inputs_embeds is None
  109. ), "You cannot specify both input_ids and inputs_embeds at the same time"
  110. assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
  111. if inputs_embeds is None:
  112. inputs_embeds = self.word_embeddings(input_ids)
  113. batch_size = inputs_embeds.shape[0]
  114. if attention_mask is not None:
  115. prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
  116. attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
  117. prompts = self.get_prompt(batch_size)
  118. inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
  119. transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
  120. # Remove prefix
  121. last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
  122. transformer_outputs["last_hidden_state"] = last_hidden_state
  123. return transformer_outputs
  124. class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
  125. """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
  126. config_class = DistributedBloomConfig
  127. def __init__(self, config: DistributedBloomConfig):
  128. BloomPreTrainedModel.__init__(self, config)
  129. if config.pre_seq_len > 0:
  130. self.transformer = DistributedBloomPrefix(config)
  131. else:
  132. self.transformer = DistributedBloomModel(config)
  133. self.lm_head = LMHead(config, self.transformer.word_embeddings)
  134. # Initialize weights and apply final processing
  135. self.post_init()
  136. def get_input_embeddings(self):
  137. return self.transformer.word_embeddings
  138. def get_output_embeddings(self):
  139. if self.config.tie_word_embeddings:
  140. return None
  141. return self.lm_head
  142. def set_input_embeddings(self, new_embeddings: nn.Embedding):
  143. assert isinstance(new_embeddings, nn.Embedding)
  144. self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
  145. assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
  146. def set_output_embeddings(self, new_lm_head: nn.Linear):
  147. with torch.no_grad():
  148. self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
  149. self.lm_head.bias[...] = new_lm_head.bias
  150. class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
  151. config_class = DistributedBloomConfig
  152. def __init__(self, config: DistributedBloomConfig):
  153. super().__init__(config)
  154. if config.pre_seq_len > 0:
  155. self.transformer = DistributedBloomPrefix(config)
  156. else:
  157. self.transformer = DistributedBloomModel(config)
  158. # Initialize weights and apply final processing
  159. self.post_init()