
Speed up loading blocks using init with meta weights (#285)

* Init WrappedBloomBlock with meta weights

---------

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Max Ryabinin, 2 years ago
commit 793726b041
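
At a high level, this commit constructs each WrappedBloomBlock on PyTorch's meta device (which records only shapes and dtypes, allocating no storage) and then fills in the real tensors from the converted checkpoint, instead of materializing and randomly initializing the weights first only to overwrite them. A minimal, standalone illustration of what the meta device provides (not part of the commit itself):

import torch

# Tensors on the meta device carry shape and dtype metadata only:
# no memory is allocated and no initialization kernels are run.
w = torch.empty(14336, 14336, device="meta")
print(w.shape, w.device)  # torch.Size([14336, 14336]) meta
# Building whole modules this way is nearly free, which is what
# accelerate's init_empty_weights() relies on under the hood.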

+ 2 - 1
pyproject.toml

@@ -14,4 +14,5 @@ profile = "black"
 line_length = 120
 combine_as_imports = true
 combine_star = true
-known_local_folder = ["tests", "cli"]
+known_local_folder = ["tests", "cli"]
+known_first_party = ["test_utils"]
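
This isort change is what drives the import reordering in the test files below: with test_utils declared as first-party, it is grouped together with the petals imports instead of with third-party packages. A hedged illustration of the resulting grouping (module names taken from the diffs below):

import pytest
import torch

from petals.client import DistributedBloomConfig
from test_utils import MODEL_NAME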

+ 16 - 10
src/petals/bloom/from_pretrained.py

@@ -13,6 +13,8 @@ import time
 from typing import Optional, OrderedDict, Union
 
 import torch
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
 from hivemind.utils.logging import get_logger
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -38,13 +40,16 @@ def load_pretrained_block(
     max_disk_space: Optional[int] = None,
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
+    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
-    block = WrappedBloomBlock(config)
+    with init_empty_weights():
+        block = WrappedBloomBlock(config)
+
     state_dict = _load_state_dict(
         converted_model_name_or_path,
         block_index,
@@ -54,16 +59,17 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
 
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
+    # dummy load, check that keys match
     report = block.load_state_dict(state_dict, strict=True)
+    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
+
+    for param_name, _ in block.named_parameters():
+        assert param_name in state_dict, f"{param_name} not in state dict"
+        param = state_dict[param_name]
+        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            param = param.to(torch_dtype)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param)
+
     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
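
Putting the new loading path together: the block is created with empty (meta) weights, and each tensor from the checkpoint is then written into the module in place. Below is a minimal sketch of this pattern, assuming a generic nn.Module subclass and an already-loaded state_dict; the real code above additionally validates keys via a strict load_state_dict pass and skips dtype casting for integer/bool tensors.

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

def materialize(block_cls, config, state_dict, torch_dtype="auto"):
    # Construct the module on the meta device: no storage is allocated,
    # so this step is fast even for multi-GB transformer blocks.
    with init_empty_weights():
        block = block_cls(config)

    # Write each checkpoint tensor into the module, optionally casting
    # floating-point weights to the requested dtype.
    for name, _ in block.named_parameters():
        tensor = state_dict[name]
        if torch_dtype != "auto" and tensor.is_floating_point():
            tensor = tensor.to(torch_dtype)
        set_module_tensor_to_device(block, name, "cpu", value=tensor)
    return block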
 

+ 1 - 1
src/petals/server/block_utils.py

@@ -30,7 +30,7 @@ def get_block_size(
             dtype is not None and load_in_8bit is not None
         ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
 
-    with init_empty_weights():
+    with init_empty_weights(include_buffers=True):
         block = WrappedBloomBlock(config)
         n_params = sum(param.numel() for param in block.parameters())
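
With include_buffers=True, any buffers registered inside the block are also created on the meta device during size estimation, so nothing is materialized in RAM. A small hedged sketch of the flag's effect, using nn.Linear plus an explicit buffer as a stand-in for the real block:

import torch
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights(include_buffers=True):
    probe = nn.Linear(4, 4)  # stand-in for WrappedBloomBlock(config)
    probe.register_buffer("scale", torch.ones(4))

assert probe.weight.device.type == "meta"  # parameters are meta as before
assert probe.scale.device.type == "meta"   # buffers are meta only with include_buffers=True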
 

+ 2 - 2
tests/test_aux_functions.py

@@ -1,9 +1,9 @@
 import pytest
 import torch
-from test_utils import MODEL_NAME
 
 from petals.client import DistributedBloomConfig
-from petals.server.throughput import measure_compute_rps, measure_network_rps
+from petals.server.throughput import measure_compute_rps
+from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked

+ 49 - 2
tests/test_block_exact_match.py

@@ -1,15 +1,18 @@
 import random
+from typing import Union
 
 import hivemind
 import pytest
 import torch
-from test_utils import *
+from transformers.models.bloom.configuration_bloom import BloomConfig
 
-from petals.bloom.from_pretrained import load_pretrained_block
+from petals.bloom.block import WrappedBloomBlock
+from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteTransformerBlock
 from petals.data_structures import UID_DELIMITER
 from petals.dht_utils import get_remote_module
+from test_utils import *
 
 
 @pytest.mark.forked
@@ -41,3 +44,47 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
 
         assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
         assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+
+
+def _old_load_pretrained_block(
+    converted_model_name_or_path: str,
+    block_index: int,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+) -> WrappedBloomBlock:
+    """Load the BLOOM block by directly initializing the weights.
+    This test is used to check consistency with the previous implementation and can be removed in the future."""
+    config = BloomConfig.from_pretrained(converted_model_name_or_path)
+
+    block = WrappedBloomBlock(config)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path,
+        block_index,
+        config,
+        cache_dir=None,
+    )
+
+    if torch_dtype == "auto":
+        with torch.no_grad():
+            for name, param in block.named_parameters():
+                assert name in state_dict, f"{name} not in state dict"
+                param.data = param.data.to(state_dict[name].dtype)
+    else:
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        block = block.to(dtype=torch_dtype)
+
+    block.load_state_dict(state_dict, strict=True)
+    return block
+
+
+@pytest.mark.forked
+def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    torch.random.manual_seed(0)
+    inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
+
+    block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
+    ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
+
+    outputs = block.forward(inputs)[0]
+    outputs_ref = ref_block.forward(inputs)[0]
+    assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)

+ 1 - 1
tests/test_chained_calls.py

@@ -7,12 +7,12 @@
 import hivemind
 import pytest
 import torch
-from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
 from petals.dht_utils import get_remote_sequence
+from test_utils import *
 
 
 @pytest.mark.forked

+ 1 - 1
tests/test_full_model.py

@@ -2,11 +2,11 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from test_utils import *
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
 
 from petals.client.remote_model import DistributedBloomForCausalLM
+from test_utils import *
 
 logger = get_logger(__name__)
 

+ 2 - 2
tests/test_remote_sequential.py

@@ -1,14 +1,14 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
-from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
+from test_utils import *
 
 logger = get_logger(__name__)
 

+ 1 - 1
tests/test_sequence_manager.py

@@ -4,11 +4,11 @@ import time
 import pytest
 import torch
 from hivemind import DHT, get_logger
-from test_utils import *
 
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
+from test_utils import *
 
 logger = get_logger(__name__)
 

+ 1 - 1
tests/test_server_stats.py

@@ -3,12 +3,12 @@ import time
 import hivemind
 import pytest
 import torch
-from test_utils import *
 
 from petals.client import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
 from petals.dht_utils import get_remote_sequence
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
+from test_utils import *
 
 
 @pytest.mark.forked

+ 1 - 1
tests/test_tensor_parallel.py

@@ -5,9 +5,9 @@ import torch
 import transformers
 from tensor_parallel import TensorParallel
 from tensor_parallel.slicing_configs import get_bloom_config
-from test_utils import MODEL_NAME
 
 from petals.bloom.from_pretrained import load_pretrained_block
+from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked