Browse Source

design interface & refactoring

Dmitry Baranchuk 3 năm trước cách đây
mục cha
commit
e66ab6f1f2

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomBlock, BloomForCausalLM, BloomModel, DistributedBloomConfig
+from src.bloom.model import BloomBlock, BloomForYou, BloomModel, DistributedBloomConfig

+ 1 - 2
src/bloom/from_pretrained.py

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
-from src.bloom import BloomBlock, BloomForCausalLM, DistributedBloomConfig
+from src.bloom import BloomBlock, DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -23,7 +23,6 @@ logger = get_logger(__file__)
 CLIENT_BRANCH = "client"
 BLOCK_BRANCH_PREFIX = "block_"
 USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
-cls = BloomForCausalLM
 FORCE_DOWNLOAD = False
 RESUME_DOWNLOAD = False
 LOCAL_FILES_ONLY = False

+ 8 - 107
src/bloom/model.py

@@ -4,8 +4,6 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
-from typing import Tuple
-
 import torch
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
@@ -314,114 +312,17 @@ class BloomModel(BloomPreTrainedModel):
 
 @add_start_docstrings(
     """
-    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
-    embeddings).
+    The Bloom interface for various applications, e.g., inference, classification...
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomModel):
+class BloomForYou(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=CausalLMOutputWithCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = super().forward(
-            input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        # Switch dtype in case word_embeddings are fp16
-        word_embeddings = self.word_embeddings.weight.t()
-        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
-        lm_logits = (hidden_states @ word_embeddings).float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
+    def __init__(self, config):
+         super().__init__(config)
+         self.transformer = BloomModel(config)
+         self.lm_head = None 
 
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
+         # Initialize weights and apply final processing
+         self.post_init()

+ 89 - 7
src/client/remote_model.py

@@ -1,28 +1,33 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Union
+from typing import Optional, Union, Tuple
 
 import hivemind
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
-from src.bloom import BloomModel, BloomForCausalLM, DistributedBloomConfig
+from src.bloom import BloomForYou, DistributedBloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
+import torch
+from hivemind import use_hivemind_log_handler
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM):
-    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+class DistributedBloomForYou(BloomForYou):
+    """BloomModel, but all transformer layers are hosted by the swarm"""
 
     def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
-        assert len(self.h) == 0
+        assert len(self.transformer.h) == 0
         config.n_layer = n_layer
-        self.h = RemoteSequential(config, dht, prefix)
+        self.transformer.h = RemoteSequential(config, dht, prefix)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
@@ -41,7 +46,84 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
 
         config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
         model = cls(config, dht, prefix)
-        model.load_state_dict(_load_state_dict(
+        model.transformer.load_state_dict(_load_state_dict(
             pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
         ), strict=True) 
         return model
+
+
+class DistributedBloomForCausalLM(DistributedBloomForYou):
+    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+        }
+
+    def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        transformer_outputs = self.transformer.forward(
+            input_ids=input_ids, return_dict=return_dict, **kwargs
+        )
+
+        # Switch dtype in case word_embeddings are fp16
+        word_embeddings = self.transformer.word_embeddings.weight.t()
+        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
+        lm_logits = (hidden_states @ word_embeddings).float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )

+ 6 - 3
tests/test_full_model.py

@@ -24,9 +24,9 @@ if not MODEL_NAME:
 REF_NAME = os.environ.get("REF_NAME")
 
 
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
@@ -52,6 +52,9 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
             recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
     recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
     recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-    recurrent_outputs = model.lm_head(recurrent_outputs)
+
+    dictionary = model.transformer.word_embeddings.weight.t()
+    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
+    recurrent_outputs = (recurrent_outputs @ dictionary).float()
     assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
     logger.info("Inference is consistent with forward")

+ 0 - 60
tests/test_full_model_new_model.py

@@ -1,60 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
-
-import torch
-import transformers
-from hivemind import use_hivemind_log_handler, get_logger
-
-from src.client.remote_model import DistributedBloomForCausalLM
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME")
-
-
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix="bloom6b3")
-    assert len(model.h) == model.config.n_layer
-
-    test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
-    parallel_outputs = model.forward(test_inputs).logits
-    assert torch.all(torch.isfinite(parallel_outputs))
-    logger.info("Forward outputs are finite")
-
-    if REF_NAME:
-        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-        # note: this creates a dummy mask to make the test compatible with older transformer versions
-        # prior to https://github.com/huggingface/transformers/pull/17837
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
-    else:
-        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
-
-    embs = model.word_embeddings(test_inputs)
-    embs = model.word_embeddings_layernorm(embs.float())
-    recurrent_outputs = []
-    with model.h.inference_session() as sess:
-        for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
-    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
-    recurrent_outputs = model.ln_f(recurrent_outputs)
-    
-    dictionary = model.word_embeddings.weight.t()
-    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-    recurrent_outputs = (recurrent_outputs @ dictionary).float()
-    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
-    logger.info("Inference is consistent with forward")