
set requires_grad=False, lm_layer -> h @ word_embeddings, rm lm_layer from converted_model

Dmitry Baranchuk 3 years ago
parent
commit
d969172208
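
In short: the converted checkpoint no longer carries a separate lm_head; logits are obtained by multiplying hidden states with the (frozen) word_embeddings matrix. A minimal standalone sketch of that idea (toy sizes, not the repo's code):

    import torch
    import torch.nn as nn

    vocab_size, hidden_size = 16, 8
    word_embeddings = nn.Embedding(vocab_size, hidden_size)
    for p in word_embeddings.parameters():
        p.requires_grad = False  # frozen, as set_requires_grad(False) does in this commit

    hidden_states = torch.randn(1, 4, hidden_size)       # output of the transformer blocks
    logits = hidden_states @ word_embeddings.weight.t()  # tied "lm_head": reuse the embedding matrix
    assert logits.shape == (1, 4, vocab_size)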

+ 6 - 0
.gitignore

@@ -1,3 +1,9 @@
+server_configs
+converted_model
+*.id
+*.ipynb
+*.out
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

+ 5 - 3
cli/convert_model.py

@@ -48,7 +48,8 @@ if __name__ == "__main__":
     config = transformers.AutoConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    model = transformers.AutoModelForCausalLM.from_pretrained(
+    # model = transformers.AutoModelForCausalLM.from_pretrained(
+    model = transformers.AutoModel.from_pretrained(    
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
     tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -59,7 +60,7 @@ if __name__ == "__main__":
     repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
     repo.git_pull()
 
-    transformer_blocks = model.transformer.h
+    transformer_blocks = model.h #transformer.h
     logger.info(
         f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
         f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
@@ -74,7 +75,8 @@ if __name__ == "__main__":
     logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
     repo.git_checkout(args.base_branch, create_branch_ok=True)
     with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
-        model.transformer.h = nn.ModuleList()
+        model.h = nn.ModuleList()
+        #model.transformer.h = nn.ModuleList()
         model.save_pretrained(".")
 
     logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")

+ 1 - 1
cli/run_local_servers.sh

@@ -98,7 +98,7 @@ do
     # Run server #
     ##############
 
-    tmux new-session -d -s "Server_${SERVER_ID}" bash deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
+    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
 done
 
 

+ 17 - 21
src/bloom/model.py

@@ -155,8 +155,9 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
-
+        
+        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
@@ -170,11 +171,18 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
     def get_input_embeddings(self):
         return self.word_embeddings
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
+    
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad=value
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -227,7 +235,7 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
@@ -311,23 +319,9 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(BloomModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
-    def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
@@ -381,7 +375,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.transformer(
+        transformer_outputs = super().forward(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -393,9 +387,11 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        hidden_states = transformer_outputs[0]
 
-        lm_logits = self.lm_head(hidden_states)
+        # Switch dtype in case word_embeddings are fp16
+        word_embeddings = self.word_embeddings.weight.t()
+        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
+        lm_logits = (hidden_states @ word_embeddings).float()
 
         loss = None
         if labels is not None:

+ 5 - 7
src/client/remote_model.py

@@ -5,7 +5,7 @@ from typing import Optional, Union
 import hivemind
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 
-from src.bloom import BloomForCausalLM, DistributedBloomConfig
+from src.bloom import BloomModel, BloomForCausalLM, DistributedBloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
@@ -20,14 +20,15 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
-        assert len(self.transformer.h) == 0
+        assert len(self.h) == 0
         config.n_layer = n_layer
-        self.transformer.h = RemoteSequential(config, dht, prefix)
+        self.h = RemoteSequential(config, dht, prefix)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
         if 'initial_peers' not in kwargs:
             raise ValueError("Please specify initial_peers=...")
+        
         dht = hivemind.DHT(
             initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
             start=True)
@@ -42,8 +43,5 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
         model = cls(config, dht, prefix)
         model.load_state_dict(_load_state_dict(
             pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
-        ), strict=True)
+        ), strict=True) 
         return model
-
-
-
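
Client-side usage is unchanged except that the remote blocks now live directly in model.h; for example, as in the test added in this commit:

    from src.client.remote_model import DistributedBloomForCausalLM

    # MODEL_NAME and INITIAL_PEERS are placeholders, set as in tests/test_full_model_new_model.py
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, prefix="bloom6b3"
    )
    assert len(model.h) == model.config.n_layer  # blocks are served remotely via RemoteSequential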

+ 2 - 0
src/client/remote_sequential.py

@@ -41,7 +41,9 @@ class RemoteSequential(nn.Module):
         self.max_retries = max_retries
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
 
+        print(prefix)
         block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
+
         logger.debug(f"Remote block uids: {block_uids}")
         self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
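
For reference, the block UIDs built here are just the prefix joined with each layer index via UID_DELIMITER (sketch; the delimiter value is a placeholder, the real one is defined in src.data_structures):

    UID_DELIMITER = "."  # placeholder; imported from src.data_structures in the actual code
    prefix, n_layer = "bloom6b3", 3
    block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(n_layer))
    # -> ('bloom6b3.0', 'bloom6b3.1', 'bloom6b3.2')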
 

+ 57 - 0
tests/test_full_model_new_model.py

@@ -0,0 +1,57 @@
+# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
+import os
+
+import torch
+import transformers
+from hivemind import use_hivemind_log_handler, get_logger
+
+from src.client.remote_model import DistributedBloomForCausalLM
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME")
+
+
+def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix="bloom6b3")
+    assert len(model.h) == model.config.n_layer
+
+    test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
+    parallel_outputs = model.forward(test_inputs).logits
+    assert torch.all(torch.isfinite(parallel_outputs))
+    logger.info("Forward outputs are finite")
+
+    if REF_NAME:
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+        # note: this creates a dummy mask to make the test compatible with older transformer versions
+        # prior to https://github.com/huggingface/transformers/pull/17837
+        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+    else:
+        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+
+    embs = model.word_embeddings(test_inputs)
+    norm_embs = model.word_embeddings_layernorm(embs.float())
+    recurrent_outputs = []
+    with model.h.inference_session() as sess:
+        for t in range(norm_embs.shape[1]):
+            recurrent_outputs.append(sess.step(norm_embs[:, t: t + 1, :]))
+    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
+    recurrent_outputs = model.ln_f(recurrent_outputs)
+    recurrent_outputs = (recurrent_outputs.to(embs.dtype) @ embs.t()).float()
+    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
+    logger.info("Inference is consistent with forward")