remote_model.py

# this code is in active development, interfaces may change
import os
from typing import Optional, Tuple

import hivemind
import torch
import torch.nn as nn
from hivemind import get_logger, use_hivemind_log_handler

from src.bloom.model import (
    BloomConfig,
    BloomForCausalLM,
    BloomForSequenceClassification,
    BloomModel,
    BloomPreTrainedModel,
    LMHead,
)
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class DistributedBloomConfig(BloomConfig):
    """
    A BLOOM config that also contains information about DHT peers.
    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
    """

    initial_peers: Tuple[str, ...] = ()  # a list of initial peers for the hivemind DHT
    dht_prefix: str  # a prefix for all DHT keys that correspond to this model (usually equal to the model name)
    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for the LM head for efficient half-precision on CPU
    num_prefix_tokens: int = 0  # the number of prefix tokens used for prompt tuning
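
# Added note: a hedged sketch of the two ways to point this config at a swarm, per the docstring above.
# The peer multiaddr below is a hypothetical placeholder, not a value taken from this repository.
#
#   config.initial_peers = ("/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeerId",)  # the model starts its own DHT client
#   # ...or reuse one running DHT instance across several models:
#   config.dht = hivemind.DHT(initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeerId"], client_mode=True, start=True)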


class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.n_layer = n_layer

        dht = (
            config.dht
            if config.dht is not None
            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
        )
        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
        self.h = RemoteSequential(config, dht, config.dht_prefix)

        # Forbid gradient accumulation for the local embeddings and layernorm
        self.set_requires_grad(False)

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value
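
# Added note: only the non-transformer modules (the input embeddings, their layernorm, and the final
# layernorm) hold parameters on this machine; self.h is a RemoteSequential proxy whose transformer
# blocks run on swarm peers. The constructor freezes the local parameters; set_requires_grad(True)
# can be used to unfreeze them for local fine-tuning.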


class DistributedBloomPrefix(DistributedBloomModel):
    """DistributedBloomModel with prefix tokens for prompt tuning"""

    def __init__(self, config):
        super().__init__(config)
        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
        self.prefix_length = config.num_prefix_tokens

        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
        self.prefix_tokens = torch.arange(self.prefix_length).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values=None,
        position_ids=None,
        head_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert (
            input_ids is None or inputs_embeds is None
        ), "You cannot specify both input_ids and inputs_embeds at the same time"
        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        batch_size = inputs_embeds.shape[0]

        if attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        prompts = self.get_prompt(batch_size)
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        transformer_outputs = super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Remove prefix
        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
        transformer_outputs["last_hidden_state"] = last_hidden_state
        return transformer_outputs
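
# Added note: since set_requires_grad(False) runs inside DistributedBloomModel.__init__ before
# prompt_embeddings is created, the prompt embeddings are the only locally trainable parameters of
# this class, which is the intended setup for prompt tuning. The first prefix_length positions are
# stripped from last_hidden_state above so downstream heads see the original sequence length.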


class DistributedBloomForCausalLM(BloomForCausalLM):
    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        if config.num_prefix_tokens > 0:
            self.transformer = DistributedBloomPrefix(config)
        else:
            self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.word_embeddings

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.word_embeddings.weight = new_embeddings.weight


class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    """BloomForSequenceClassification, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        super().__init__(config)
        if config.num_prefix_tokens > 0:
            self.transformer = DistributedBloomPrefix(config)
        else:
            self.transformer = DistributedBloomModel(config)

        # Initialize weights and apply final processing
        self.post_init()
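

# The block below is an added, hedged usage sketch, not part of the library API: it shows one way
# these classes might be wired together, assuming src.bloom.model follows the Hugging Face
# from_pretrained API and that a swarm serving the given dht_prefix is reachable. The checkpoint
# name, dht_prefix, peer multiaddr, and token ids are hypothetical placeholders.
if __name__ == "__main__":
    config = DistributedBloomConfig.from_pretrained("bigscience/bloom")  # placeholder checkpoint
    config.dht_prefix = "bigscience/bloom"  # usually the model name
    config.initial_peers = ("/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeerId",)  # placeholder multiaddr

    model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom", config=config)
    input_ids = torch.tensor([[2, 3, 4]])  # placeholder token ids, shape (batch, seq_len)
    with torch.no_grad():
        outputs = model(input_ids)  # embeddings run locally, transformer blocks run on swarm peers
    logger.info(f"logits shape: {outputs.logits.shape}")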