sequence_manager.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from __future__ import annotations
  2. import threading
  3. from typing import List, Optional, Sequence, Tuple, Union
  4. from hivemind import DHT, DHTExpiration
  5. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  6. from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
  7. from src.dht_utils import get_remote_module_infos
# Select hivemind's log-handler mode before the module logger is created.
# NOTE(review): "in_root_logger" presumably attaches hivemind's handler to the
# root logger so records from every module share one format — confirm against hivemind.
use_hivemind_log_handler("in_root_logger")
# Module-level logger; it is named after this file's path (``__file__``).
logger = get_logger(__file__)
  10. class RemoteSequenceManager:
  11. """Keeps and updates the meta-information about which peers host which blocks"""
  12. dht: DHT
  13. block_uids: List[ModuleUID]
  14. block_infos: List[Optional[RemoteModuleInfo]]
  15. spans_by_priority: List[RemoteSpanInfo] # sorted from best to worst
  16. spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
  17. last_update_time: DHTExpiration
  18. lock_changes: threading.Lock
  19. def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
  20. self.dht = dht
  21. self.block_uids = list(block_uids)
  22. self.block_infos = [None] * len(self.block_uids)
  23. self.spans_by_priority = []
  24. self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
  25. self.last_update_time = -float("inf")
  26. self.lock_changes = threading.Lock()
  27. self.update_()
  28. for uid, info in zip(self.block_uids, self.block_infos):
  29. assert info is not None, f"Found no remote peers for block {uid}"
  30. assert self.spans_by_priority and self.spans_containing_block
  31. def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
  32. """Get a RemoteSequenceManager for a sub-sequence of blocks"""
  33. assert isinstance(ix, (int, slice))
  34. if not isinstance(ix, slice):
  35. ix = slice(int(ix), int(ix) + 1, 1)
  36. with self.lock_changes:
  37. subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
  38. subseq.block_infos = self.block_infos[ix]
  39. subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
  40. subseq.last_update_time = self.last_update_time
  41. return subseq
  42. def update_(self):
  43. with self.lock_changes:
  44. self.update_block_infos_()
  45. self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
  46. def update_block_infos_(self):
  47. new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
  48. assert len(new_block_infos) == len(self.block_uids)
  49. for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
  50. if info is None:
  51. logger.warning(f"Found no block info for block {uid}")
  52. if not isinstance(info, RemoteModuleInfo):
  53. logger.warning(f"Unexpected dht entry type for {uid}: {info}")
  54. if not info.servers:
  55. logger.warning(f"Found no active peers for block {uid}")
  56. if info.uid != uid:
  57. logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
  58. self.block_infos[block_index] = info
  59. @staticmethod
  60. def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
  61. closed_spans = []
  62. active_spans = {}
  63. for block_index, info in enumerate(block_infos):
  64. for peer_id, server in info.servers.items():
  65. if server.state != ServerState.ONLINE:
  66. continue
  67. if peer_id not in active_spans:
  68. active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
  69. else: # peer_id in active_spans
  70. active_spans[peer_id].end = block_index + 1
  71. for peer_id in list(active_spans.keys()):
  72. if (
  73. peer_id not in info.servers
  74. or info.servers[peer_id].state != ServerState.ONLINE
  75. or block_index == len(block_infos) - 1
  76. ):
  77. closed_spans.append(active_spans.pop(peer_id))
  78. assert not active_spans
  79. closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
  80. spans_containing_block = tuple(list() for _ in range(len(block_infos)))
  81. for span in closed_spans:
  82. for block_index in range(span.start, span.end):
  83. spans_containing_block[block_index].append(span)
  84. return closed_spans, spans_containing_block
  85. def __len__(self):
  86. return len(self.block_uids)