remote_sequential.py

from __future__ import annotations

import contextlib
import logging
from typing import List, Optional, Union

import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn

import src
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import UID_DELIMITER, RemoteSpanInfo
from src.dht_utils import _create_remote_modules_from_infos
from src.server.handler import TransformerConnectionHandler  # assumed import path for the handler used below

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """

    def __init__(
        self,
        config: src.DistributedBloomConfig,
        dht: DHT,
        prefix: str,
        max_retries: int = 3,
        p2p: Optional[P2P] = None,
        sequence_manager: Optional[RemoteSequenceManager] = None,
    ):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        super().__init__()
        self.config = config
        self.dht = dht
        self.prefix = prefix
        self.max_retries = max_retries
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p

        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
        if sequence_manager is None:
            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
            self.is_subsequence = False
        else:
            assert isinstance(sequence_manager.block_uids, list)
            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} blocks")
            self.sequence_manager = sequence_manager
            self.is_subsequence = self.sequence_manager.block_uids == block_uids

    def forward(self, inputs: torch.Tensor):
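        # Push the hidden states through all config.n_layer remote blocks in order.
        # Each block call is retried up to self.max_retries times; the last failure is re-raised.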
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise e
                    else:
                        logger.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
        assert isinstance(ix, (int, slice))
        if isinstance(ix, int):
            assert 0 <= ix < self.config.n_layer
            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
            return module
        else:
            return RemoteSequential(
                self.config,
                self.dht,
                prefix=self.prefix,
                max_retries=self.max_retries,
                p2p=self.p2p,
                sequence_manager=self.sequence_manager[ix],
            )

    def __iter__(self):
        for block_index in range(self.config.n_layer):
            yield self[block_index]

    def __len__(self):
        return len(self.sequence_manager)

    def inference_session(self) -> RemoteSequentialInferenceSession:
        self.sequence_manager.update_()
        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
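
# Example usage of the inference session (sketch; `blocks` is a RemoteSequential instance
# and the tensor shape and step count are illustrative assumptions):
#
#   with blocks.inference_session() as session:
#       for _ in range(10):
#           hidden_states = session.step(torch.randn(1, 1, blocks.config.n_embed))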


class RemoteSequentialInferenceSession:
    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P):
        self.sequence_manager = sequence_manager
        self.p2p = p2p
        self.closed = False
        self.chosen_spans: List[RemoteSpanInfo] = []
        self.stack = contextlib.ExitStack()
        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []

    def __enter__(self):
        assert not self.closed and not self.chosen_spans
        self.stack.__enter__()
        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        self.chosen_spans.extend(self.sequence_manager.make_sequence())
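        # Each chosen span covers a contiguous run of block uids served by one peer (span.peer_id),
        # ending right before span.end; spans are traversed in order so that all blocks are covered.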
        current_block = 0
        for chosen_span in self.chosen_spans:
            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
            # TODO begin throwaway prototype code
            remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)
            _ = remote.info  # TODO fix
            span_uids = self.sequence_manager.block_uids[current_block : chosen_span.end]
            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
            self.inference_sessions.append(remote.inference_session())
            self.stack.enter_context(self.inference_sessions[-1])
            current_block = chosen_span.end
            # TODO end throwaway prototype code

        return self

    def step(self, inputs: torch.Tensor):
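        # One inference step: feed the current hidden states through every span's session in order,
        # using each server's output as the next server's input.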
        assert not self.closed
        for session in self.inference_sessions:
            outputs = session.step(inputs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            self.stack.__exit__(*exc_details or (None, None, None))
            self.inference_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()