Procházet zdrojové kódy

WIP: make DistributedBloom compliant with HF interface

justheuristic před 3 roky
rodič
revize
4695071ad2
5 změnil soubory, kde provedl 135 přidání a 138 odebrání
  1. 2 0
      README.md
  2. 2 2
      cli/inference_one_block.py
  3. 4 20
      src/bloom/block.py
  4. 99 13
      src/bloom/model.py
  5. 28 103
      src/client/remote_model.py

+ 2 - 0
README.md

@@ -4,6 +4,8 @@ Early dev prototype for decentralized bloom. Not for public eyes **yet**.
 Roadmap: [issue #12](https://github.com/learning-at-home/bloom-demo/issues/12)
 
 Latest news @ main branch (max 5):
+- [Jul 4] @dbaranchuk implemented chained rpc_forward and rpc_backward (for prompt tuning)
+- [Jul 3] @dbaranchuk optimized DistributedBloom to reduce embeddings/logits RAM usage
 - [Jul 1] @yozh added RemoteSequential and test for full model exact match
 - [June 28] @dbaranchunk added quick deployment scripts for testnet
 

+ 2 - 2
cli/inference_one_block.py

@@ -5,7 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tqdm.auto import trange
 
 from src.bloom.block import BloomBlock
-from src.bloom.model import DistributedBloomConfig
+from src.bloom.model import BloomConfig
 from src.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     if args.device is None:
         args.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    config = DistributedBloomConfig.from_json_file(args.config)
+    config = BloomConfig.from_json_file(args.config)
     block = BloomBlock(config, args.layer_index).to(args.device)
 
     cache = None

+ 4 - 20
src/bloom/block.py

@@ -43,16 +43,8 @@ class BloomAttention(nn.Module):
             self.layer_number,
         )
 
-        if config.compression == "qint8":
-            self.query_key_value = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-            self.dense = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-        else:
-            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-            self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
 
         self.attention_dropout = nn.Dropout(config.attention_dropout)
 
@@ -173,16 +165,8 @@ class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
-        if config.compression == "qint8":
-            self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-            self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
-                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
-            )
-        else:
-            self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
-            self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
+        self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
+        self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
         self.hidden_dropout = config.hidden_dropout
         self.gelu_impl = BloomGelu()
 

+ 99 - 13
src/bloom/model.py

@@ -3,8 +3,10 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
+from typing import Tuple
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
@@ -13,25 +15,19 @@ from transformers.file_utils import (add_code_sample_docstrings, add_start_docst
                                      add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
+from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import logging
 
 from src.bloom.block import BloomBlock
-from src.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
 logger = logging.get_logger(__file__)
 
 _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
-_CONFIG_FOR_DOC = "DistributedBloomConfig"
+_CONFIG_FOR_DOC = "BloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
-class DistributedBloomConfig(_VanillaBloomConfig):
-    compression: str = "none"
-    slow_but_exact: bool = False
-
-
 class BloomPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
     """
@@ -39,7 +35,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     models.
     """
 
-    config_class = DistributedBloomConfig
+    config_class = BloomConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
@@ -312,17 +308,107 @@ class BloomModel(BloomPreTrainedModel):
 
 @add_start_docstrings(
     """
-    The Bloom interface for various applications, e.g., inference, classification...
+    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForYou(BloomPreTrainedModel):
+class BloomForCausalLM(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
-        self.lm_head = None
-
         # Initialize weights and apply final processing
         self.post_init()
+
+    def get_output_embeddings(self):
+        return self.transformer.word_embeddings
+
+    def set_output_embeddings(self, new_embeddings):
+        self.transformer.word_embeddings.weight = new_embeddings.weight
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+        }
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutputWithCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        labels=None,
+        return_dict=None,
+        **kwargs
+    ):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
+        word_embeddings = self.transformer.word_embeddings.weight
+
+        # Switch dtype in case word_embeddings are fp16/bf16
+        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
+        lm_logits = F.linear(hidden_states, word_embeddings).float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )

+ 28 - 103
src/client/remote_model.py

@@ -1,15 +1,13 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple, Union
+from typing import Optional, Union, Tuple
 
 import hivemind
-import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
-from src.bloom import BloomForYou, DistributedBloomConfig
+from src.bloom.model import BloomModel, BloomForCausalLM, BloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
+from src.bloom.model import BloomPreTrainedModel
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -17,111 +15,38 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class DistributedBloomForYou(BloomForYou):
+class DistributedBloomConfig(BloomConfig):
+    """
+    A bloom config that contains information about DHT peers.
+    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
+    """
+    initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
+    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
+    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
+
+
+class DistributedBloomModel(BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
+    def __init__(self, config: DistributedBloomConfig):
+        assert self.config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
+        assert self.config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
 
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
-        assert len(self.transformer.h) == 0
+        assert len(self.h) == 0
         config.n_layer = n_layer
-        self.transformer.h = RemoteSequential(config, dht, prefix)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        if "initial_peers" not in kwargs:
-            raise ValueError("Please specify initial_peers=...")
-
-        dht = hivemind.DHT(
-            initial_peers=kwargs.pop("initial_peers"), client_mode=kwargs.pop("client_mode", True), start=True
-        )
 
-        if "prefix" not in kwargs:
-            logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
-            assert (
-                UID_DELIMITER not in pretrained_model_name_or_path
-            ), f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
-        prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
+        dht = config.dht if config.dht is not None else hivemind.DHT(
+            initial_peers=config.initial_peers, client_mode=True, start=True)
+        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
+        self.h = RemoteSequential(config, dht, config.dht_prefix)
 
-        config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
-        model = cls(config, dht, prefix)
-        model.transformer.load_state_dict(
-            _load_state_dict(pretrained_model_name_or_path, use_auth_token=kwargs.get("use_auth_token")), strict=True
-        )
-        return model
 
-
-class DistributedBloomForCausalLM(DistributedBloomForYou):
+class DistributedBloomForCausalLM(BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel().__init__(config)
+        self.transformer = DistributedBloomModel(config)
+        # Initialize weights and apply final processing
+        self.post_init()
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
-
-        # Switch dtype in case word_embeddings are fp16
-        word_embeddings = self.transformer.word_embeddings.weight.t()
-        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
-        lm_logits = (hidden_states @ word_embeddings).float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )