sequence_manager.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
from __future__ import annotations

import random
import threading
import time
from typing import List, Optional, Sequence, Tuple, Union

from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src.client.spending_policy import NoSpendingPolicy
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.handler import TransformerConnectionHandler
# Route this module's log records through hivemind's root log handler.
use_hivemind_log_handler("in_root_logger")
# NOTE(review): logger is named after the file path via __file__ rather than the
# more common __name__ — presumably the project convention; confirm before changing.
logger = get_logger(__file__)
  15. class RemoteSequenceManager:
  16. """
  17. Keeps and updates the meta-information about which peers host which blocks.
  18. In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
  19. """
  20. def __init__(
  21. self,
  22. dht: DHT,
  23. block_uids: Sequence[ModuleUID],
  24. p2p: P2P,
  25. max_retries: int = 3,
  26. timeout: float = 5,
  27. min_backoff: float = 1,
  28. ):
  29. assert len(block_uids) > 0, "Sequences must contain at least one block"
  30. self.dht, self.p2p = dht, p2p
  31. self.block_uids: List[ModuleUID] = list(block_uids)
  32. self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
  33. self.spans_by_priority: List[RemoteSpanInfo] = [] # sorted from best to worst
  34. self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
  35. self.last_update_time: DHTExpiration = -float("inf")
  36. self.max_retries = max_retries
  37. self.timeout, self.min_backoff = timeout, min_backoff
  38. self._rpc_info = None
  39. self.lock_changes = threading.Lock()
  40. self.update_()
  41. for uid, info in zip(self.block_uids, self.block_infos):
  42. assert info is not None, f"Found no remote peers for block {uid}"
  43. assert self.spans_by_priority and self.spans_containing_block
  44. def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
  45. """
  46. Form a sequence of remote servers that collectively serve all consecutive layers
  47. :param start_index: optional index of the first module in a sequence, default = the first of block_uids
  48. :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
  49. """
  50. end_index = end_index if end_index is not None else len(self.block_uids)
  51. span_sequence = []
  52. current_index = start_index
  53. while current_index < end_index:
  54. candidate_spans = self.spans_containing_block[current_index]
  55. chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
  56. assert chosen_span.start <= current_index < chosen_span.end
  57. span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
  58. current_index = chosen_span.end
  59. return span_sequence
  60. def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
  61. """Get a RemoteSequenceManager for a sub-sequence of blocks"""
  62. assert isinstance(ix, (int, slice))
  63. if not isinstance(ix, slice):
  64. ix = slice(int(ix), int(ix) + 1, 1)
  65. with self.lock_changes:
  66. subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
  67. subseq.block_infos = self.block_infos[ix]
  68. subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
  69. subseq.last_update_time = self.last_update_time
  70. return subseq
  71. def update_(self):
  72. with self.lock_changes:
  73. self.update_block_infos_()
  74. self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
  75. def update_block_infos_(self):
  76. new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
  77. assert len(new_block_infos) == len(self.block_uids)
  78. for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
  79. if info is None:
  80. logger.warning(f"Found no block info for block {uid}")
  81. continue
  82. if not isinstance(info, RemoteModuleInfo):
  83. logger.warning(f"Unexpected dht entry type for {uid}: {info}")
  84. if not info.servers:
  85. logger.warning(f"Found no active peers for block {uid}")
  86. if info.uid != uid:
  87. logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
  88. self.block_infos[block_index] = info
  89. @staticmethod
  90. def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
  91. closed_spans = []
  92. active_spans = {}
  93. for block_index, info in enumerate(block_infos):
  94. if info is not None:
  95. for peer_id, server in info.servers.items():
  96. if server.state != ServerState.ONLINE:
  97. continue
  98. if peer_id not in active_spans:
  99. active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
  100. else: # peer_id in active_spans
  101. active_spans[peer_id].end = block_index + 1
  102. for peer_id in list(active_spans.keys()):
  103. if (
  104. info is None
  105. or peer_id not in info.servers
  106. or info.servers[peer_id].state != ServerState.ONLINE
  107. or block_index == len(block_infos) - 1
  108. ):
  109. closed_spans.append(active_spans.pop(peer_id))
  110. assert not active_spans, f"spans: {active_spans}"
  111. closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
  112. spans_containing_block = tuple(list() for _ in range(len(block_infos)))
  113. for span in closed_spans:
  114. for block_index in range(span.start, span.end):
  115. spans_containing_block[block_index].append(span)
  116. return closed_spans, spans_containing_block
  117. def __len__(self):
  118. return len(self.block_uids)
  119. @property
  120. def rpc_info(self):
  121. """Return the rpc_info queried from one of the servers that hold the first block"""
  122. if self._rpc_info is None:
  123. retries = 0
  124. for i in range(self.max_retries):
  125. try:
  126. self.update_()
  127. peer_id = random.choice(list(self.block_infos[0].servers.keys()))
  128. stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
  129. outputs = RemoteExpertWorker.run_coroutine(
  130. stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
  131. )
  132. self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
  133. break
  134. except Exception as e:
  135. retries += 1
  136. if retries >= self.max_retries:
  137. raise e
  138. else:
  139. logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
  140. return self._rpc_info