test_block_exact_match.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import random
  2. from typing import Union
  3. import pytest
  4. import torch
  5. from transformers.models.bloom.configuration_bloom import BloomConfig
  6. from petals.bloom.block import WrappedBloomBlock
  7. from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
  8. from petals.client import DistributedBloomConfig, RemoteSequential
  9. from test_utils import *
  10. @pytest.mark.forked
  11. def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
  12. config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
  13. remote_sequential = RemoteSequential(config)
  14. for block_index in random.sample(range(config.n_layer), 3):
  15. remote_block = remote_sequential[block_index]
  16. inputs = torch.randn(1, 8, config.hidden_size)
  17. outputs_forward = remote_block(inputs)
  18. outputs_inference = []
  19. with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
  20. for i in range(inputs.shape[1]):
  21. outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
  22. # test that max length is respected
  23. with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
  24. sess.step(inputs[:, -1:, :])
  25. assert "Maximum length exceeded" in repr(exc_info.value)
  26. outputs_inference = torch.cat(outputs_inference, dim=1)
  27. ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
  28. (outputs_local,) = ref_block(inputs)
  29. assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
  30. assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
  31. def _old_load_pretrained_block(
  32. converted_model_name_or_path: str,
  33. block_index: int,
  34. torch_dtype: Union[torch.dtype, str] = "auto",
  35. ) -> WrappedBloomBlock:
  36. """Load the BLOOM block by directly initializing the weights.
  37. This test is used to check consistency with the previous implementation and can be removed in the future."""
  38. config = BloomConfig.from_pretrained(converted_model_name_or_path)
  39. block = WrappedBloomBlock(config)
  40. state_dict = _load_state_dict(
  41. converted_model_name_or_path,
  42. block_index,
  43. config,
  44. cache_dir=None,
  45. )
  46. if torch_dtype == "auto":
  47. with torch.no_grad():
  48. for name, param in block.named_parameters():
  49. assert name in state_dict, f"{name} not in state dict"
  50. param.data = param.data.to(state_dict[name].dtype)
  51. else:
  52. assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
  53. block = block.to(dtype=torch_dtype)
  54. block.load_state_dict(state_dict, strict=True)
  55. return block
  56. @pytest.mark.forked
  57. def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
  58. config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
  59. torch.random.manual_seed(0)
  60. inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
  61. block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
  62. ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
  63. outputs = block.forward(inputs)[0]
  64. outputs_ref = ref_block.forward(inputs)[0]
  65. assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)