test_remote_sequential.py

import os

import torch
import transformers
from hivemind import DHT, get_logger, use_hivemind_log_handler

from src import RemoteSequential
from src.client.remote_model import DistributedBloomForCausalLM, DistributedBloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
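
# How to run (a sketch; the peer address and model name below are placeholders for
# whatever your swarm actually serves):
#   export INITIAL_PEERS="/ip4/127.0.0.1/tcp/31337/p2p/<peer_id>"
#   export MODEL_NAME="<name of the model served by the swarm>"
#   pytest test_remote_sequential.py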
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()

MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME environment variable with the name of the model to be tested")
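
# The test below checks that splitting RemoteSequential into two halves and running
# them back-to-back reproduces both the outputs and the input gradients of a single
# pass over the full remote sequence.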


def test_remote_sequential():
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
    grad_proj = torch.randn(1, 5, config.hidden_size)

    sequential = RemoteSequential(config, dht)
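
    # Reference run: forward + backward through the full remote sequence to obtain
    # baseline outputs and input gradients.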
    full_outputs = sequential(test_inputs)
    (full_outputs * grad_proj).sum().backward()
    assert test_inputs.grad is not None
    full_grad = test_inputs.grad.clone()
    test_inputs.grad.data.zero_()
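
    # Split the remote sequence into two halves; together they must cover all
    # config.n_layer blocks.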
    first_half = sequential[:config.n_layer // 2]
    second_half = sequential[config.n_layer // 2:]
    assert len(first_half) + len(second_half) == len(sequential)
    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
    for m in sequential, first_half, second_half:
        assert isinstance(repr(m), str)
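
    # Running the halves back-to-back should reproduce the full-sequence outputs.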
    hidden = first_half(test_inputs)
    assert isinstance(hidden, torch.Tensor)
    assert hidden.shape == test_inputs.shape
    assert hidden.requires_grad
    second_half_outputs = second_half(hidden)
    assert torch.allclose(second_half_outputs, full_outputs)
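
    # Backpropagating through second_half and then first_half should reproduce the
    # reference gradient computed above.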
    (second_half_outputs * grad_proj).sum().backward()
    assert torch.allclose(test_inputs.grad, full_grad)