minimize peak CPU memory usage

justheuristic · 3 years ago
commit dc277b34da
2 changed files with 9 additions and 6 deletions
  1. src/bloom/ops.py (+1 -1)
  2. tests/test_full_model.py (+8 -5)

src/bloom/ops.py (+1 -1)

@@ -101,7 +101,7 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor)
         attention_mask: ([`torch.tensor`], *required*):
             attention mask to pre-process
     """
-    assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
+    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
     unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
     # ^-- [batch, max_len], values correspond to element indices after removing padding
     # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0

tests/test_full_model.py (+8 -5)

@@ -2,6 +2,7 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
+
 from test_utils import *
 
 from src.client.remote_model import DistributedBloomForCausalLM
@@ -13,13 +14,14 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS,
+                                                        low_cpu_mem_usage=True, torch_dtype=torch.float32)
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 
-    with torch.no_grad():
+    with torch.inference_mode():
         parallel_outputs = model.forward(test_inputs).logits
         assert torch.all(torch.isfinite(parallel_outputs))
         logger.info("Forward outputs are finite")
@@ -36,14 +38,15 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
         logger.info("Inference is consistent with forward")
 
-        del model, recurrent_outputs
+        del model, embs, recurrent_outputs
 
         if REF_NAME:
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            ref_model = transformers.BloomForCausalLM.from_pretrained(
+                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32)
             dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
             # note: this creates a dummy mask to make the test compatible with older transformer versions
             # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
             assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
             logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
             del ref_model, ref_outputs, dummy_mask
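
Beyond the loading flags, the test now runs its forward passes under torch.inference_mode() instead of torch.no_grad() and also deletes embs, both of which trim the memory kept alive during the test. A minimal illustration of the swap (not from the commit):

import torch

x = torch.randn(2, 3)
with torch.inference_mode():
    # autograd is disabled as with no_grad(); inference_mode additionally skips view and
    # version-counter tracking, so intermediate tensors carry less bookkeeping overhead
    y = x * 2
assert not y.requires_grad and y.is_inference()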