
fix "some weights of DistributedBloomForCausalLM were not initialized from the model checkpoint at bloom-testing/test-bloomd-350m and are newly initialized: ['lm_head.word_embeddings.weight']""

justheuristic, 3 years ago
parent commit 9f72a2417b
4 changed files with 70 additions and 27 deletions
  1. .github/workflows/run-tests.yaml (+8 -2)
  2. src/bloom/model.py (+16 -0)
  3. src/client/remote_model.py (+23 -6)
  4. tests/test_full_model.py (+23 -19)

+ 8 - 2
.github/workflows/run-tests.yaml

@@ -64,15 +64,21 @@ jobs:
         run: |
           python -m cli.run_server --converted_model_name_or_path bloom-testing/test-bloomd-350m --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
-          SERVER1_PID=$$
+          SERVER1_PID=$!
           
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
           
+          python -m cli.run_server --converted_model_name_or_path bloom-testing/test-bloomd-350m --block_indices 12:24 \
+            --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          SERVER2_PID=$!
+
           sleep 30  # wait for server to download layers
           
           # test individual blocks
           export PYTHONPATH=. REF_NAME=bloom-testing/test-bloomd-350m
           BLOCK_UID=bloom-testing/test-bloomd-350m.0 REF_INDEX=0 pytest tests/test_block_exact_match.py
-          BLOCK_UID=bloom-testing/test-bloomd-350m.4 REF_INDEX=4 pytest tests/test_block_exact_match.py
+          BLOCK_UID=bloom-testing/test-bloomd-350m.19 REF_INDEX=19 pytest tests/test_block_exact_match.py
+          
+          kill -s SIGINT $SERVER1_PID $SERVER2_PID
           echo "Done!"

+ 16 - 0
src/bloom/model.py

@@ -447,6 +447,22 @@ class LMHead(nn.Module):
         self.word_embeddings = word_embeddings
         self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
 
+    @property
+    def in_features(self) -> int:
+        return self.word_embeddings.num_embeddings
+
+    @property
+    def out_features(self) -> int:
+        return self.word_embeddings.embedding_dim
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    @property
+    def bias(self):
+        return None
+
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
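
The new properties give LMHead the read-only surface of an nn.Linear (in_features, out_features, weight, bias) while the actual parameter stays inside the shared word_embeddings. Because the head now reports the embedding matrix as its weight, checkpoint loading and weight tying in transformers see an already-initialized, tied parameter instead of a missing lm_head.word_embeddings.weight. A condensed sketch of the pattern (hypothetical class name; the real forward uses chunking for efficient fp16 on CPU, omitted here):

    import torch
    import torch.nn as nn

    class TiedLMHead(nn.Module):
        """Condensed stand-in for LMHead: it owns no weight of its own and
        projects hidden states with the shared input-embedding matrix."""

        def __init__(self, word_embeddings: nn.Embedding):
            super().__init__()
            self.word_embeddings = word_embeddings  # shared with the transformer

        # nn.Linear-like read-only attributes, mirroring the diff above
        @property
        def in_features(self) -> int:
            return self.word_embeddings.num_embeddings

        @property
        def out_features(self) -> int:
            return self.word_embeddings.embedding_dim

        @property
        def weight(self):
            return self.word_embeddings.weight

        @property
        def bias(self):
            return None

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # logits = hidden_states @ E^T, reusing the embedding matrix as the head
            return nn.functional.linear(hidden_states, self.word_embeddings.weight)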
 

+ 23 - 6
src/client/remote_model.py

@@ -16,7 +16,6 @@ from src.bloom.model import (
     LMHead,
 )
 from src.client.remote_sequential import RemoteSequential
-from src.data_structures import UID_DELIMITER
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -64,6 +63,12 @@ class DistributedBloomModel(BloomModel):
         for p in self.parameters():
             p.requires_grad = value
 
+    def forward(self, *args, use_cache=None, **kwargs):
+        if use_cache:
+            raise ValueError("Distributed forward does not support use_cache; for efficient cache-aware generation, "
+                             "please use model.transformer.inference_session() or model.generate(...)")
+        return super().forward(*args, use_cache=False, **kwargs)
+
 
 class DistributedBloomPrefix(DistributedBloomModel):
     """DistributedBloomModel with prefix tokens for prompt tuning"""
@@ -131,7 +136,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
-    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    """ Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
 
@@ -146,11 +151,23 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_output_embeddings(self):
-        return self.lm_head.word_embeddings
+    def get_input_embeddings(self):
+        return self.transformer.word_embeddings
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings.weight = new_embeddings.weight
+    def get_output_embeddings(self):
+        if self.config.tie_word_embeddings:
+            return None
+        return self.lm_head
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding):
+        assert isinstance(new_embeddings, nn.Embedding)
+        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
+        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
+
+    def set_output_embeddings(self, new_lm_head: nn.Linear):
+        with torch.no_grad():
+            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
+            self.lm_head.bias[...] = new_lm_head.bias
 
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
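
The embedding accessors above encode the tying contract: set_input_embeddings points the transformer and the LM head at the same nn.Embedding, and get_output_embeddings returns None whenever config.tie_word_embeddings is set, so transformers' tie_weights()/post_init() has no separate output matrix to clone or re-initialize (the source of the "newly initialized" warning in the commit title). A self-contained toy illustration of that contract, using hypothetical class names rather than the real DistributedBloomForCausalLM:

    import torch
    import torch.nn as nn

    class TinyTiedCausalLM(nn.Module):
        def __init__(self, vocab_size=64, hidden_size=8, tie_word_embeddings=True):
            super().__init__()
            self.tie_word_embeddings = tie_word_embeddings
            self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
            self.lm_head_weight = self.word_embeddings.weight  # tied, no separate copy

        def get_input_embeddings(self):
            return self.word_embeddings

        def get_output_embeddings(self):
            # Returning None tells transformers-style tie_weights()/post_init()
            # that there is no separate output matrix to clone or re-initialize.
            return None if self.tie_word_embeddings else self.lm_head_weight

        def set_input_embeddings(self, new_embeddings: nn.Embedding):
            assert isinstance(new_embeddings, nn.Embedding)
            self.word_embeddings = new_embeddings
            self.lm_head_weight = new_embeddings.weight  # keep the tie intact

        def forward(self, input_ids):
            hidden = self.word_embeddings(input_ids)
            return hidden @ self.lm_head_weight.t()  # logits via the tied matrix

    model = TinyTiedCausalLM()
    model.set_input_embeddings(nn.Embedding(64, 8))
    assert model.get_output_embeddings() is None
    assert model.word_embeddings.weight.data_ptr() == model.lm_head_weight.data_ptr()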

+ 23 - 19
tests/test_full_model.py

@@ -24,9 +24,10 @@ if not MODEL_NAME:
 REF_NAME = os.environ.get("REF_NAME")
 
 
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
@@ -35,26 +36,29 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     logger.info("Forward outputs are finite")
 
     if REF_NAME:
-        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-        # note: this creates a dummy mask to make the test compatible with older transformer versions
-        # prior to https://github.com/huggingface/transformers/pull/17837
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+        with torch.no_grad():
+            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+            # note: this creates a dummy mask to make the test compatible with older transformer versions
+            # prior to https://github.com/huggingface/transformers/pull/17837
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+            del ref_model, ref_outputs
     else:
         logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
 
-    embs = model.transformer.word_embeddings(test_inputs)
-    embs = model.transformer.word_embeddings_layernorm(embs)
-    recurrent_outputs = []
-    with model.transformer.h.inference_session() as sess:
-        for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
-    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
-    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-    dictionary = model.transformer.word_embeddings.weight.t()
-    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-    recurrent_outputs = (recurrent_outputs @ dictionary).float()
+    with torch.inference_mode():
+        embs = model.transformer.word_embeddings(test_inputs)
+        embs = model.transformer.word_embeddings_layernorm(embs)
+        recurrent_outputs = []
+        with model.transformer.h.inference_session() as sess:
+            for t in range(embs.shape[1]):
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
+        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
+
+        dictionary = model.transformer.word_embeddings.weight.t()
+        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
+        recurrent_outputs = (recurrent_outputs @ dictionary).float()
     assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
     logger.info("Inference is consistent with forward")