
Merge pull request #17 from learning-at-home/update-model

Update client and model
Dmitry Baranchuk 3 years ago
commit 0b5a68983f

+ 3 - 3
cli/convert_model.py

@@ -48,7 +48,7 @@ if __name__ == "__main__":
     config = transformers.AutoConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    model = transformers.AutoModelForCausalLM.from_pretrained(
+    model = transformers.AutoModel.from_pretrained(    
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
     tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -59,7 +59,7 @@ if __name__ == "__main__":
     repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
     repo.git_pull()
 
-    transformer_blocks = model.transformer.h
+    transformer_blocks = model.h
     logger.info(
         f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
         f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
@@ -74,7 +74,7 @@ if __name__ == "__main__":
     logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
     repo.git_checkout(args.base_branch, create_branch_ok=True)
     with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
-        model.transformer.h = nn.ModuleList()
+        model.h = nn.ModuleList()
         model.save_pretrained(".")
 
     logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
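Note on the convert_model.py change above: AutoModelForCausalLM wraps the BLOOM backbone under .transformer, while the bare AutoModel is the backbone itself, which is why the attribute path changes from model.transformer.h to model.h. A minimal sketch of the difference (the checkpoint name is only an example, not part of this repo):

```python
import transformers

# Any BLOOM checkpoint behaves the same way; "bigscience/bloom-560m" is an example name.
lm = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
bare = transformers.AutoModel.from_pretrained("bigscience/bloom-560m")

# The causal-LM wrapper keeps the backbone under .transformer; the bare model exposes .h directly.
assert len(lm.transformer.h) == len(bare.h)
```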

+ 1 - 1
cli/run_local_servers.sh

@@ -98,7 +98,7 @@ do
     # Run server #
     ##############
 
-    tmux new-session -d -s "Server_${SERVER_ID}" bash deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
+    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
 done
 

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomBlock, BloomForCausalLM, BloomModel, DistributedBloomConfig
+from src.bloom.model import BloomBlock, BloomForYou, BloomModel, DistributedBloomConfig

+ 1 - 2
src/bloom/from_pretrained.py

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
-from src.bloom import BloomBlock, BloomForCausalLM, DistributedBloomConfig
+from src.bloom import BloomBlock, DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -23,7 +23,6 @@ logger = get_logger(__file__)
 CLIENT_BRANCH = "client"
 BLOCK_BRANCH_PREFIX = "block_"
 USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
-cls = BloomForCausalLM
 FORCE_DOWNLOAD = False
 RESUME_DOWNLOAD = False
 LOCAL_FILES_ONLY = False

+ 18 - 121
src/bloom/model.py

@@ -4,8 +4,6 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
-from typing import Tuple
-
 import torch
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
@@ -155,8 +153,9 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
-
+        
+        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
@@ -170,11 +169,18 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
     def get_input_embeddings(self):
         return self.word_embeddings
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
+    
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad=value
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -227,7 +233,7 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
@@ -306,126 +312,17 @@ class BloomModel(BloomPreTrainedModel):
 
 @add_start_docstrings(
     """
-    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
-    embeddings).
+    The Bloom interface for various applications, e.g., inference, classification...
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForYou(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=CausalLMOutputWithCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.transformer(
-            input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-
-        lm_logits = self.lm_head(hidden_states)
+         super().__init__(config)
+         self.transformer = BloomModel(config)
+         self.lm_head = None 
 
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
+         # Initialize weights and apply final processing
+         self.post_init()

+ 86 - 6
src/client/remote_model.py

@@ -1,21 +1,26 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Union
+from typing import Optional, Union, Tuple
 
 import hivemind
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
-from src.bloom import BloomForCausalLM, DistributedBloomConfig
+from src.bloom import BloomForYou, DistributedBloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
+import torch
+from hivemind import use_hivemind_log_handler
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM):
-    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+class DistributedBloomForYou(BloomForYou):
+    """BloomModel, but all transformer layers are hosted by the swarm"""
 
     def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
@@ -28,6 +33,7 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
         if 'initial_peers' not in kwargs:
             raise ValueError("Please specify initial_peers=...")
+        
         dht = hivemind.DHT(
             initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
             start=True)
@@ -40,10 +46,84 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
 
         config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
         model = cls(config, dht, prefix)
-        model.load_state_dict(_load_state_dict(
+        model.transformer.load_state_dict(_load_state_dict(
             pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
-        ), strict=True)
+        ), strict=True) 
         return model
 
 
+class DistributedBloomForCausalLM(DistributedBloomForYou):
+    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+        }
+
+    def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        transformer_outputs = self.transformer.forward(
+            input_ids=input_ids, return_dict=return_dict, **kwargs
+        )
+
+        # Switch dtype in case word_embeddings are fp16
+        word_embeddings = self.transformer.word_embeddings.weight.t()
+        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
+        lm_logits = (hidden_states @ word_embeddings).float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
 
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )

+ 1 - 0
src/client/remote_sequential.py

@@ -42,6 +42,7 @@ class RemoteSequential(nn.Module):
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
 
         block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
+
         logger.debug(f"Remote block uids: {block_uids}")
         self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
 

+ 6 - 3
tests/test_full_model.py

@@ -24,9 +24,9 @@ if not MODEL_NAME:
 REF_NAME = os.environ.get("REF_NAME")
 
 
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
@@ -52,6 +52,9 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
             recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
     recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
     recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-    recurrent_outputs = model.lm_head(recurrent_outputs)
+
+    dictionary = model.transformer.word_embeddings.weight.t()
+    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
+    recurrent_outputs = (recurrent_outputs @ dictionary).float()
     assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
     logger.info("Inference is consistent with forward")