# test_remote_sequential.py — integration tests for petals RemoteSequential
  1. import pytest
  2. import torch
  3. import torch.nn.functional as F
  4. from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
  5. from hivemind.proto import runtime_pb2
  6. from test_utils import *
  7. from petals.bloom.from_pretrained import load_pretrained_block
  8. from petals.client import RemoteSequenceManager, RemoteSequential
  9. from petals.client.remote_model import DistributedBloomConfig
  10. from petals.data_structures import UID_DELIMITER
  11. logger = get_logger(__file__)
@pytest.mark.forked
def test_remote_sequential():
    """End-to-end test: a RemoteSequential run in one pass matches the same model split into two
    halves, and a lossy-compression sequence manager degrades outputs only within tolerance."""
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
    test_attention_mask = torch.ones((1, 5))
    # Fixed projection tensor so backward() yields a deterministic, comparable gradient
    grad_proj = torch.randn(1, 5, config.hidden_size)

    sequential = RemoteSequential(config, dht)

    # Reference run: full pipeline forward + backward, capture the input gradient
    full_outputs = sequential(test_inputs, test_attention_mask)
    (full_outputs * grad_proj).sum().backward()
    assert test_inputs.grad is not None
    full_grad = test_inputs.grad.clone()
    # Zero the grad so the split-pipeline backward below accumulates from scratch
    test_inputs.grad.data.zero_()

    # Split the pipeline into two slices and check they compose back to the whole
    first_half = sequential[: config.n_layer // 2]
    second_half = sequential[config.n_layer // 2 :]
    assert len(first_half) + len(second_half) == len(sequential)
    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
    for m in sequential, first_half, second_half:
        assert isinstance(repr(m), str)

    hidden = first_half(test_inputs, test_attention_mask)
    assert isinstance(hidden, torch.Tensor)
    assert hidden.shape == test_inputs.shape
    assert hidden.requires_grad
    second_half_outputs = second_half(hidden, test_attention_mask)
    # Two-stage forward must match the single-pass outputs and gradients
    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)

    (second_half_outputs * grad_proj).sum().backward()
    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)

    # test RemoteSequential with lossy compression
    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
    lossy_sequential = RemoteSequential(
        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True)
    )

    test_inputs.grad = None
    approx_outputs = lossy_sequential(test_inputs, test_attention_mask)
    (approx_outputs * grad_proj).sum().backward()

    # Compression must actually change the results (not bit-identical) ...
    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
    # ... but only within a small average error
    assert abs(approx_outputs - full_outputs).mean() < 0.01
    absmax = abs(full_grad).max()
    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
  52. class DummyCustomSequenceManager(RemoteSequenceManager):
  53. """A sequence manager that compresses inputs/outputs during forward and backward pass."""
  54. @property
  55. def rpc_info(self):
  56. rpc_info = super().rpc_info
  57. dims = (2048, 1024)
  58. compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
  59. rpc_info["forward_schema"] = (compressed_input_schema, compressed_input_schema), dict() # (args, kwargs)
  60. return rpc_info
  61. def get_request_metadata(self, protocol: str, *args, **kwargs):
  62. metadata = super().get_request_metadata(protocol, *args, **kwargs)
  63. if protocol == "rpc_forward":
  64. metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
  65. elif protocol == "rpc_backward":
  66. metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,)
  67. return metadata
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
    """Check that per-block ("deep") prompts passed to RemoteSequential reproduce a local
    reference that runs each pretrained block and adds its prompt by hand."""
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    remote_sequential = RemoteSequential(config, dht)

    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
    attention_mask = torch.ones((batch_size, seq_len + pre_seq_len))
    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
    # One prompt tensor per layer: (n_layer, batch, pre_seq_len, hidden)
    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)

    # Detach + requires_grad_ rebinds each as a fresh autograd leaf (normalize above produced non-leaf tensors)
    input_prompts = input_prompts.detach().requires_grad_(True)
    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)

    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)

    outputs = remote_sequential(inputs_with_prompts, attention_mask, prompts=intermediate_prompts)

    (outputs * output_proj).sum().backward()
    assert intermediate_prompts.grad is not None

    # Local reference: clone the prompts as independent leaves so gradients can be compared
    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)

    assert input_prompts_ref.grad is None
    assert intermediate_prompts_ref.grad is None

    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
    for block_index in range(config.n_layer):
        block_prompt = intermediate_prompts_ref[block_index]
        # Add this block's prompt to the first pre_seq_len positions before running the block
        outputs_ref[:, : block_prompt.shape[1]] += block_prompt

        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
        (outputs_ref,) = block(outputs_ref, attention_mask)

    assert torch.allclose(outputs_ref, outputs, atol=1e-3)

    (outputs_ref * output_proj).sum().backward()
    assert input_prompts_ref.grad is not None
    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
    assert intermediate_prompts_ref.grad is not None
    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)