test_tensor_parallel.py

import random

import pytest
import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config

from petals.server.from_pretrained import load_pretrained_block
from test_utils import MODEL_NAME


@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    if model_config.model_type != "bloom":
        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

    # Pick an arbitrary transformer block and load it in full precision.
    block_index = random.randint(0, 10)
    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

    tp_config = None
    if custom_config:
        tp_config = get_bloom_config(model_config, devices)

    batch_size = 2
    prefix_length = 5

    # Two identical leaf copies of each input: one for the reference block,
    # one for the tensor-parallel block, so gradients can be compared.
    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
    grad_proj = torch.rand_like(test_inputs1)

    # Reference pass: run the unmodified block with an attention cache,
    # prefix first, then the main inputs against that cache.
    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
    y_ref.backward(grad_proj)

    # Same computation through the tensor-parallel wrapper.
    block_tp = TensorParallel(block, devices, config=tp_config)
    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
    y_ours.backward(grad_proj)

    # Forward outputs and input gradients must closely match the reference.
    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
    assert torch.allclose(y_ours, y_ref, atol=1e-5)
    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
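
# For intuition, the wrapper exercised above follows tensor_parallel's generic
# module-wrapping pattern. A minimal sketch (illustrative only, not run by this
# suite; it assumes the library's default slicing rules cover a bare nn.Linear):
#
#     linear = torch.nn.Linear(1024, 1024)
#     linear_tp = TensorParallel(linear, ["cpu", "cpu"])  # shard over two CPU "devices"
#     x = torch.randn(2, 3, 1024)
#     assert torch.allclose(linear_tp(x), linear(x), atol=1e-5)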