remote_model.py

# this code is in active development, interfaces may change
import os
from typing import List, Optional, Tuple, Union

import hivemind
import torch
import torch.nn as nn
from hivemind import get_logger, use_hivemind_log_handler

from src.bloom.model import (
    BloomConfig,
    BloomForCausalLM,
    BloomForSequenceClassification,
    BloomModel,
    BloomPreTrainedModel,
    LMHead,
)
from src.client.remote_generation import RemoteGenerationMixin
from src.client.remote_sequential import RemoteSequential
from src.utils.generation_algorithms import DecodingAlgorithm
from src.utils.generation_constraints import ABConstraint

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class DistributedBloomConfig(BloomConfig):
    """
    A bloom config that contains information about DHT peers.
    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
    """

    initial_peers: Tuple[str, ...] = ()  # a list of initial peers for the hivemind DHT
    dht_prefix: str  # a prefix for all DHT keys that correspond to this model (usually equal to the model name)
    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for the LM head, for efficient half-precision on CPU
    num_prefix_tokens: int = 0  # the number of tokens used for prompt tuning
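
# Example (illustrative sketch, not part of the original file): constructing a config that
# points at an existing swarm. INITIAL_PEERS, the checkpoint name, and the dht_prefix value
# are placeholders; use the multiaddrs and prefix announced by your servers. If the
# checkpoint's config does not already store dht_prefix, set it explicitly after loading.
#
#   INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeerID"]
#   config = DistributedBloomConfig.from_pretrained(
#       "bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS
#   )
#   config.dht_prefix = "bloom6b3"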


class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.n_layer = n_layer

        dht = (
            config.dht
            if config.dht is not None
            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
        )
        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
        self.h = RemoteSequential(config, dht, config.dht_prefix)

        # Forbid gradient accumulation for the local embeddings and layernorm
        self.set_requires_grad(False)

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value

    def forward(self, *args, use_cache=None, **kwargs):
        if use_cache:
            raise ValueError(
                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
                "please use model.transformer.inference_session() or model.generate(...)"
            )
        return super().forward(*args, use_cache=False, **kwargs)
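
# Example (illustrative sketch, not part of the original file): a plain forward pass through
# the swarm. The tokenizer and input text are placeholders, and `config` is assumed to be set
# up as in the config example above, with servers reachable via the DHT. In practice the local
# embeddings should come from a converted checkpoint (e.g. via from_pretrained) rather than
# being randomly initialized as they are when constructing the model directly from a config.
#
#   from transformers import BloomTokenizerFast
#
#   tokenizer = BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
#   model = DistributedBloomModel(config)
#   inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
#   hidden_states = model(input_ids=inputs["input_ids"]).last_hidden_state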


class DistributedBloomPrefix(DistributedBloomModel):
    """DistributedBloomModel with prefix tokens for prompt tuning"""

    def __init__(self, config):
        super().__init__(config)
        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
        self.prefix_length = config.num_prefix_tokens

        # Trainable prompt embeddings; created after super().__init__, so they are not frozen
        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
        self.prefix_tokens = torch.arange(self.prefix_length).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values=None,
        position_ids=None,
        head_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert (
            input_ids is None or inputs_embeds is None
        ), "You cannot specify both input_ids and inputs_embeds at the same time"
        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        batch_size = inputs_embeds.shape[0]

        # Extend the attention mask so it also covers the prefix positions
        if attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        # Prepend the trainable prompt embeddings to the input embeddings
        prompts = self.get_prompt(batch_size)
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        transformer_outputs = super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Remove the prefix positions from the output hidden states
        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
        transformer_outputs["last_hidden_state"] = last_hidden_state
        return transformer_outputs
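
# Example (illustrative sketch, not part of the original file): prompt tuning updates only the
# local prompt embeddings, since the remaining local weights are frozen in
# DistributedBloomModel.__init__ and the transformer blocks live on remote servers. The
# hyperparameters, `input_ids`/`attention_mask` batch, and loss function are placeholders.
#
#   config.num_prefix_tokens = 16
#   model = DistributedBloomPrefix(config)
#   opt = torch.optim.AdamW(model.prompt_embeddings.parameters(), lr=1e-2)
#
#   out = model(input_ids=input_ids, inputs_embeds=None, attention_mask=attention_mask)
#   loss = some_loss_fn(out.last_hidden_state)  # placeholder objective
#   loss.backward()
#   opt.step()
#   opt.zero_grad()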


class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        if config.num_prefix_tokens > 0:
            self.transformer = DistributedBloomPrefix(config)
        else:
            self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.word_embeddings

    def get_output_embeddings(self):
        if self.config.tie_word_embeddings:
            return None
        return self.lm_head

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        assert isinstance(new_embeddings, nn.Embedding)
        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings

    def set_output_embeddings(self, new_lm_head: nn.Linear):
        with torch.no_grad():
            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
            self.lm_head.bias[...] = new_lm_head.bias

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_token_id: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABConstraint] = [],
        **model_kwargs,
    ) -> torch.Tensor:
        return RemoteGenerationMixin.generate(
            self,
            inputs,
            do_sample,
            temperature,
            top_k,
            top_p,
            eos_token_id,
            max_new_tokens,
            decoding_algorithm,
            provided_constraints,
            **model_kwargs,
        )
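
# Example (illustrative sketch, not part of the original file): text generation through the
# swarm. The checkpoint name is a placeholder for a model converted for distributed inference
# whose config already stores dht_prefix; INITIAL_PEERS and the tokenizer are assumed to be
# set up as in the earlier examples.
#
#   model = DistributedBloomForCausalLM.from_pretrained(
#       "bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS
#   )
#   prompt = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
#   outputs = model.generate(prompt, max_new_tokens=8, do_sample=True, temperature=0.8)
#   print(tokenizer.decode(outputs[0]))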


class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    """BloomForSequenceClassification, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        super().__init__(config)
        if config.num_prefix_tokens > 0:
            self.transformer = DistributedBloomPrefix(config)
        else:
            self.transformer = DistributedBloomModel(config)

        # Initialize weights and apply final processing
        self.post_init()
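
# Example (illustrative sketch, not part of the original file): sequence classification with
# the swarm-hosted backbone. The classification head defined by BloomForSequenceClassification
# stays local and trainable, while the backbone's frozen parts and remote layers are not
# updated. Checkpoint name, num_labels, and the input batch are placeholders.
#
#   model = DistributedBloomForSequenceClassification.from_pretrained(
#       "bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, num_labels=2
#   )
#   logits = model(input_ids=prompt).logits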