remote_sequential.py

from __future__ import annotations

from typing import Optional, Union

import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn

import src
from src.client.inference_session import InferenceSession
from src.client.sequence_manager import RemoteSequenceManager
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
from src.data_structures import UID_DELIMITER
from src.utils.misc import DUMMY

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """

    def __init__(
        self,
        config: src.DistributedBloomConfig,
        dht: DHT,
        dht_prefix: Optional[str] = None,
        p2p: Optional[P2P] = None,
        sequence_manager: Optional[RemoteSequenceManager] = None,
    ):
        super().__init__()
        self.config = config
        self.dht = dht
        self.dht_prefix = dht_prefix or config.dht_prefix
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p

        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
        if sequence_manager is None:
            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
            self.is_subsequence = False
        else:
            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
            self.sequence_manager = sequence_manager
            assert isinstance(sequence_manager.block_uids, list)
            self.is_subsequence = self.sequence_manager.block_uids != block_uids

    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
        return outputs

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
        assert isinstance(ix, (int, slice))
        if isinstance(ix, int):
            return RemoteTransformerBlock(
                self.config,
                self.dht,
                dht_prefix=self.dht_prefix,
                p2p=self.p2p,
                sequence_manager=self.sequence_manager[ix],
            )
        else:
            return RemoteSequential(
                self.config,
                self.dht,
                dht_prefix=self.dht_prefix,
                p2p=self.p2p,
                sequence_manager=self.sequence_manager[ix],
            )

    def __iter__(self):
        for block_index in range(len(self)):
            yield self[block_index]

    def __len__(self):
        return len(self.sequence_manager)

    def inference_session(self, **kwargs) -> InferenceSession:
        self.sequence_manager.update_()
        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)

    def extra_repr(self) -> str:
        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

class RemoteTransformerBlock(RemoteSequential):
    """A single transformer block hosted by the swarm.

    This class is deprecated and kept for backward compatibility.
    It will be removed soon in favor of using ``RemoteSequential`` directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert len(self) == 1, "Remote Block is a sequence of size 1"

    def extra_repr(self):
        return f"{self.sequence_manager.block_uids[0]}"
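

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). A minimal,
# hedged example of driving RemoteSequential end to end. Assumptions: the
# checkpoint name and INITIAL_PEERS below are placeholders to be replaced with
# a real model and the multiaddrs of running swarm servers, and
# src.DistributedBloomConfig.from_pretrained(...) is available as it is
# elsewhere in this repo.
if __name__ == "__main__":
    import hivemind

    INITIAL_PEERS: list = []  # placeholder: multiaddrs of running servers

    config = src.DistributedBloomConfig.from_pretrained("bigscience/test-bloomd-6b3")  # placeholder checkpoint
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)

    blocks = RemoteSequential(config, dht)  # spans all config.n_layer remote blocks
    first_half = blocks[: len(blocks) // 2]  # slicing returns a shorter RemoteSequential

    hidden_states = torch.randn(1, 8, config.hidden_size)  # (batch, seq_len, hidden_size)
    outputs = blocks(hidden_states)  # forward pass routed through the swarm

    session = blocks.inference_session()  # incremental (token-by-token) inference over the same blocks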