# test_remote_sequential.py
  1. import pytest
  2. import torch
  3. from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler
  4. from hivemind.proto import runtime_pb2
  5. from test_utils import *
  6. from petals.bloom.from_pretrained import load_pretrained_block
  7. from petals.client import RemoteSequential
  8. from petals.client.remote_model import DistributedBloomConfig
  9. use_hivemind_log_handler("in_root_logger")
  10. logger = get_logger(__file__)
  11. @pytest.mark.forked
  12. def test_remote_sequential():
  13. config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
  14. dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
  15. test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
  16. grad_proj = torch.randn(1, 5, config.hidden_size)
  17. sequential = RemoteSequential(config, dht)
  18. full_outputs = sequential(test_inputs)
  19. (full_outputs * grad_proj).sum().backward()
  20. assert test_inputs.grad is not None
  21. full_grad = test_inputs.grad.clone()
  22. test_inputs.grad.data.zero_()
  23. first_half = sequential[: config.n_layer // 2]
  24. second_half = sequential[config.n_layer // 2 :]
  25. assert len(first_half) + len(second_half) == len(sequential)
  26. assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
  27. for m in sequential, first_half, second_half:
  28. assert isinstance(repr(m), str)
  29. hidden = first_half(test_inputs)
  30. assert isinstance(hidden, torch.Tensor)
  31. assert hidden.shape == test_inputs.shape
  32. assert hidden.requires_grad
  33. second_half_outputs = second_half(hidden)
  34. assert torch.allclose(second_half_outputs, full_outputs)
  35. (second_half_outputs * grad_proj).sum().backward()
  36. assert torch.allclose(test_inputs.grad, full_grad)
  37. @pytest.mark.forked
  38. def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
  39. config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
  40. dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
  41. remote_sequential = RemoteSequential(config, dht)
  42. inputs = torch.randn(batch_size, seq_len, config.hidden_size)
  43. output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
  44. input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
  45. intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
  46. input_prompts = input_prompts.detach().requires_grad_(True)
  47. intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
  48. inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
  49. assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
  50. outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
  51. (outputs * output_proj).sum().backward()
  52. assert intermediate_prompts.grad is not None
  53. input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
  54. intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
  55. assert input_prompts_ref.grad is None
  56. assert intermediate_prompts_ref.grad is None
  57. outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
  58. for block_index in range(config.n_layer):
  59. block_prompt = intermediate_prompts_ref[block_index]
  60. outputs_ref[:, : block_prompt.shape[1]] += block_prompt
  61. block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
  62. (outputs_ref,) = block(outputs_ref)
  63. assert torch.allclose(outputs_ref, outputs)
  64. (outputs_ref * output_proj).sum().backward()
  65. assert input_prompts_ref.grad is not None
  66. assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
  67. assert intermediate_prompts_ref.grad is not None
  68. assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)