test_remote_sequential.py

import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler

from src import RemoteSequential
from src.client.remote_model import DistributedBloomConfig
from test_utils import *

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


def test_remote_sequential():
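    """Forward/backward through RemoteSequential should match running its two sliced halves back-to-back."""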
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
    grad_proj = torch.randn(1, 5, config.hidden_size)

    sequential = RemoteSequential(config, dht)
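
    # Reference run: forward and backward through the entire remote pipeline.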
    full_outputs = sequential(test_inputs)
    (full_outputs * grad_proj).sum().backward()
    assert test_inputs.grad is not None
    full_grad = test_inputs.grad.clone()
    test_inputs.grad.data.zero_()
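
    # Slicing should split the pipeline into two consecutive halves of remote layers.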
    first_half = sequential[: config.n_layer // 2]
    second_half = sequential[config.n_layer // 2 :]
    assert len(first_half) + len(second_half) == len(sequential)
    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
    for m in sequential, first_half, second_half:
        assert isinstance(repr(m), str)
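
    # Running the halves back-to-back must reproduce the full pipeline's outputs and input gradients.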
    hidden = first_half(test_inputs)
    assert isinstance(hidden, torch.Tensor)
    assert hidden.shape == test_inputs.shape
    assert hidden.requires_grad
    second_half_outputs = second_half(hidden)
    assert torch.allclose(second_half_outputs, full_outputs)

    (second_half_outputs * grad_proj).sum().backward()
    assert torch.allclose(test_inputs.grad, full_grad)