|
@@ -1,15 +1,18 @@
|
|
|
import random
|
|
|
+from typing import Union
|
|
|
|
|
|
import hivemind
|
|
|
import pytest
|
|
|
import torch
|
|
|
-from test_utils import *
|
|
|
+from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
|
|
|
|
-from petals.bloom.from_pretrained import load_pretrained_block
|
|
|
+from petals.bloom.block import WrappedBloomBlock
|
|
|
+from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
|
|
|
from petals.client import DistributedBloomConfig
|
|
|
from petals.client.remote_sequential import RemoteTransformerBlock
|
|
|
from petals.data_structures import UID_DELIMITER
|
|
|
from petals.dht_utils import get_remote_module
|
|
|
+from test_utils import *
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
@@ -41,3 +44,47 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
|
|
|
|
|
|
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
|
|
|
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
|
|
|
+
|
|
|
+
|
|
|
+def _old_load_pretrained_block(
|
|
|
+ converted_model_name_or_path: str,
|
|
|
+ block_index: int,
|
|
|
+ torch_dtype: Union[torch.dtype, str] = "auto",
|
|
|
+) -> WrappedBloomBlock:
|
|
|
+ """Load the BLOOM block by directly initializing the weights.
|
|
|
+ This test is used to check consistency with the previous implementation and can be removed in the future."""
|
|
|
+ config = BloomConfig.from_pretrained(converted_model_name_or_path)
|
|
|
+
|
|
|
+ block = WrappedBloomBlock(config)
|
|
|
+ state_dict = _load_state_dict(
|
|
|
+ converted_model_name_or_path,
|
|
|
+ block_index,
|
|
|
+ config,
|
|
|
+ cache_dir=None,
|
|
|
+ )
|
|
|
+
|
|
|
+ if torch_dtype == "auto":
|
|
|
+ with torch.no_grad():
|
|
|
+ for name, param in block.named_parameters():
|
|
|
+ assert name in state_dict, f"{name} not in state dict"
|
|
|
+ param.data = param.data.to(state_dict[name].dtype)
|
|
|
+ else:
|
|
|
+ assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
|
+ block = block.to(dtype=torch_dtype)
|
|
|
+
|
|
|
+ block.load_state_dict(state_dict, strict=True)
|
|
|
+ return block
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.forked
|
|
|
+def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
|
|
|
+ config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
|
|
+ torch.random.manual_seed(0)
|
|
|
+ inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
|
|
|
+
|
|
|
+ block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
|
|
|
+ ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
|
|
|
+
|
|
|
+ outputs = block.forward(inputs)[0]
|
|
|
+ outputs_ref = ref_block.forward(inputs)[0]
|
|
|
+ assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)
|