remote_sequential.py

from __future__ import annotations

import contextlib
import random
from typing import Optional, Union

import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn

import src
from src.client.remote_block import RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """

    def __init__(
        self,
        config: src.DistributedBloomConfig,
        dht: DHT,
        prefix: str,
        max_retries: int = 3,
        p2p: Optional[P2P] = None,
        sequence_manager: Optional[RemoteSequenceManager] = None,
    ):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        super().__init__()
        self.config = config
        self.dht = dht
        self.prefix = prefix
        self.max_retries = max_retries
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p

        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
        if sequence_manager is None:
            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
            self.is_subsequence = False
        else:
            assert isinstance(sequence_manager.block_uids, list)
            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} blocks")
            self.sequence_manager = sequence_manager
            # mark as a sub-sequence if the reused manager does not cover the full set of block uids
            self.is_subsequence = self.sequence_manager.block_uids != block_uids

    def forward(self, inputs: torch.Tensor):
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        # run the hidden states through every remote block in order, retrying each block up to max_retries times
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise e
                    else:
                        logger.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
        assert isinstance(ix, (int, slice))
        if isinstance(ix, int):
            assert 0 <= ix < self.config.n_layer
            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
            return module
        else:
            return RemoteSequential(
                self.config,
                self.dht,
                prefix=self.prefix,
                max_retries=self.max_retries,
                p2p=self.p2p,
                sequence_manager=self.sequence_manager[ix],
            )
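
    # Illustrative indexing sketch (not made anywhere in this file; assumes `remote_blocks`
    # is an existing RemoteSequential instance):
    #   single_block = remote_blocks[3]                      # one RemoteTransformerBlock
    #   second_half = remote_blocks[config.n_layer // 2 :]   # a RemoteSequential over the remaining blocks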

    def __iter__(self):
        for block_index in range(self.config.n_layer):
            yield self[block_index]

    def __len__(self):
        return len(self.sequence_manager)

    def inference_session(self) -> RemoteSequentialInferenceSession:
        self.sequence_manager.update_()
        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)


class RemoteSequentialInferenceSession:
    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
        self.remote_sequence_info = remote_sequence_info
        self.p2p = p2p
        self.closed = False
        self.stack = contextlib.ExitStack()
        self.active_sessions = []

    def __enter__(self):
        assert not self.closed
        self.stack.__enter__()
        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        # greedily cover the full block range with randomly chosen server spans
        current_block = 0
        while current_block != len(self.remote_sequence_info):
            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
            chosen_span = random.choice(candidate_spans)  # TODO this is temporary code
            assert chosen_span.start <= current_block < chosen_span.end

            # TODO begin throwaway prototype code
            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
            _ = remote.info  # TODO fix
            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
            self.active_sessions.append(remote.inference_session())
            self.stack.enter_context(self.active_sessions[-1])
            current_block = chosen_span.end
            # TODO end throwaway prototype code

        return self

    def step(self, inputs: torch.Tensor):
        assert not self.closed
        # pass the hidden states through each server's span in order
        for session in self.active_sessions:
            outputs = session.step(inputs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            self.stack.__exit__(*exc_details or (None, None, None))
            self.active_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()