remote_model.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # this code is in active development, interfaces may change
  2. from typing import Optional, Tuple
  3. import hivemind
  4. import torch
  5. import torch.nn as nn
  6. from hivemind import get_logger, use_hivemind_log_handler
  7. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
  8. from src.bloom.model import (
  9. BloomConfig,
  10. BloomForCausalLM,
  11. BloomForSequenceClassification,
  12. BloomModel,
  13. BloomPreTrainedModel,
  14. LMHead,
  15. )
  16. from src.client.remote_generation import RemoteGenerationMixin
  17. from src.client.remote_sequential import RemoteSequential
  18. from src.utils.misc import DUMMY
# Configure hivemind's log handler in "in_root_logger" mode
# (presumably: hivemind records are emitted via the root logger — see hivemind docs).
use_hivemind_log_handler("in_root_logger")
# Module logger, named after this file's path.
logger = get_logger(__file__)
  21. class DistributedBloomConfig(BloomConfig):
  22. """
  23. A bloom config that contains information about DHT peers.
  24. To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
  25. """
  26. initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
  27. dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
  28. dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
  29. chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
  30. pre_seq_len: int = 0 # a number of tokens for prompt tuning.
  31. tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
  32. class DistributedBloomModel(BloomModel):
  33. """BloomModel, but all transformer layers are hosted by the swarm"""
  34. config_class = DistributedBloomConfig
  35. def __init__(self, config: DistributedBloomConfig):
  36. assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
  37. assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
  38. n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
  39. super().__init__(config)
  40. assert len(self.h) == 0
  41. config.n_layer = n_layer
  42. dht = (
  43. config.dht
  44. if config.dht is not None
  45. else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
  46. )
  47. assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
  48. self.h = RemoteSequential(config, dht, config.dht_prefix)
  49. # Forbid accumulate grads for embeddings and layernorm
  50. self.set_requires_grad(False)
  51. if config.tuning_mode and "ptune" in config.tuning_mode:
  52. assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
  53. self.pre_seq_len = config.pre_seq_len
  54. self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
  55. self.prefix_tokens = torch.arange(self.pre_seq_len).long()
  56. if config.tuning_mode == "deep_ptune":
  57. self.intermediate_prompt_embeddings = nn.Embedding(
  58. self.pre_seq_len,
  59. config.num_hidden_layers * config.hidden_size
  60. # ^-- TODO: should be num_hidden_layers - 1
  61. )
  62. self.intermediate_prompt_embeddings.weight.data.zero_()
  63. elif config.tuning_mode:
  64. raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
  65. def set_requires_grad(self, value):
  66. for p in self.parameters():
  67. p.requires_grad = value
  68. def get_prompt(self, batch_size):
  69. prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
  70. prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
  71. prompts = self.prompt_embeddings(prefix_tokens)
  72. if self.config.tuning_mode == "deep_ptune":
  73. intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
  74. intermediate_prompts = intermediate_prompts.view(
  75. batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size # TODO: should be len(self.h) - 1
  76. )
  77. intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
  78. else:
  79. intermediate_prompts = DUMMY
  80. return prompts, intermediate_prompts
  81. def forward(
  82. self,
  83. input_ids: Optional[torch.LongTensor] = None,
  84. inputs_embeds: Optional[torch.Tensor] = None,
  85. attention_mask: Optional[torch.Tensor] = None,
  86. **kwargs,
  87. ):
  88. assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
  89. for k, v in kwargs.items():
  90. if not (v is None or v is False):
  91. logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
  92. if input_ids is not None and inputs_embeds is not None:
  93. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  94. elif input_ids is not None:
  95. input_shape = input_ids.size()
  96. input_ids = input_ids.view(-1, input_shape[-1])
  97. elif inputs_embeds is not None:
  98. input_shape = inputs_embeds.size()[:-1]
  99. else:
  100. raise ValueError("You have to specify either input_ids or inputs_embeds")
  101. if inputs_embeds is None:
  102. inputs_embeds = self.word_embeddings(input_ids)
  103. if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
  104. batch_size = inputs_embeds.shape[0]
  105. prompts, intermediate_prompts = self.get_prompt(batch_size)
  106. inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
  107. hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
  108. output_shape = input_shape + (hidden_states.size(-1),)
  109. if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
  110. hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
  111. else:
  112. hidden_states = self.h(hidden_states)
  113. # Remove prefix
  114. if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
  115. hidden_states = hidden_states[:, self.pre_seq_len :]
  116. # Add last hidden state
  117. hidden_states = self.ln_f(hidden_states)
  118. hidden_states = hidden_states.view(output_shape)
  119. return BaseModelOutputWithPastAndCrossAttentions(
  120. last_hidden_state=hidden_states,
  121. past_key_values=None,
  122. hidden_states=None,
  123. attentions=None,
  124. )
  125. class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
  126. """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
  127. config_class = DistributedBloomConfig
  128. def __init__(self, config: DistributedBloomConfig):
  129. BloomPreTrainedModel.__init__(self, config)
  130. self.transformer = DistributedBloomModel(config)
  131. self.lm_head = LMHead(config, self.transformer.word_embeddings)
  132. # Initialize weights and apply final processing
  133. self.post_init()
  134. def get_input_embeddings(self):
  135. return self.transformer.word_embeddings
  136. def get_output_embeddings(self):
  137. if self.config.tie_word_embeddings:
  138. return None
  139. return self.lm_head
  140. def set_input_embeddings(self, new_embeddings: nn.Embedding):
  141. assert isinstance(new_embeddings, nn.Embedding)
  142. self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
  143. assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
  144. def set_output_embeddings(self, new_lm_head: nn.Linear):
  145. with torch.no_grad():
  146. self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
  147. self.lm_head.bias[...] = new_lm_head.bias
  148. class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
  149. config_class = DistributedBloomConfig
  150. def __init__(self, config: DistributedBloomConfig):
  151. super().__init__(config)
  152. self.transformer = DistributedBloomModel(config)
  153. # Initialize weights and apply final processing
  154. self.post_init()