remote_sequential.py

from functools import partial
from typing import Optional, Tuple

import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn

from src import DistributedBloomConfig
from src.data_structures import UID_DELIMITER, RemoteModuleInfo
from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Sequential):
    """
    A sequence of transformer blocks hosted by the swarm.
    Each block is a client for a module served by remote peers discovered through the DHT.
    """

    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        super().__init__()
        self.config = config
        self.dht = dht
        # Get a handle to the DHT's underlying libp2p instance for talking to remote servers
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

        self.prefix = prefix
        self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
        logger.debug(f"Remote block uids: {self.block_uids}")
        # Fetch DHT records describing which peers currently serve each block
        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(
            dht.run_coroutine(
                partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
                return_future=False,
            )
        )
        self.max_retries = max_retries

        assert len(self.block_infos) == len(self.block_uids)
        for uid, info in zip(self.block_uids, self.block_infos):
            assert isinstance(info, (type(None), RemoteModuleInfo)), f"Unexpected dht entry for {uid}: {info}"
            assert info is not None, f"Found no active peers for block {uid}"
            assert isinstance(info.peer_ids, set), f"expected peer_ids to be a set, got {info.peer_ids}"
            assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"

    def forward(self, inputs: torch.Tensor):
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        # Sequentially run inputs through every remote block, retrying each block up to max_retries times
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise e
                    else:
                        logger.warning(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, block_index: int):
        assert 0 <= block_index < self.config.n_layer
        # Build a fresh client module for the requested block from its cached DHT record
        (module,) = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
        return module

    def __iter__(self):
        for block_index in range(self.config.n_layer):
            yield self[block_index]