remote_sequence_info.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from __future__ import annotations
  2. import dataclasses
  3. import threading
  4. from functools import partial
  5. from typing import Tuple, List, Optional, Sequence, NamedTuple
  6. from hivemind import DHT, PeerID
  7. from hivemind.utils.logging import use_hivemind_log_handler, get_logger
  8. from src.data_structures import ModuleUID, RemoteModuleInfo
  9. from src.dht_utils import _get_remote_module_infos
  10. use_hivemind_log_handler("in_root_logger")
  11. logger = get_logger(__file__)
  12. Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
  13. @dataclasses.dataclass(frozen=False, init=False)
  14. class RemoteSequenceInfo:
  15. """Keeps and updates the meta-information about which peers host which blocks"""
  16. dht: DHT
  17. block_uids: List[ModuleUID, ...]
  18. block_infos: List[Optional[RemoteModuleInfo], ...]
  19. spans_by_priority: List[Span] # sorted from best to worst
  20. spans_containing_block: Tuple[List[Span], ...]
  21. lock_changes: threading.Lock
  22. def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
  23. self.dht = dht
  24. self.block_uids = list(block_uids)
  25. self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
  26. self.spans_by_priority = []
  27. self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
  28. self.lock_changes = threading.Lock()
  29. self.update_()
  30. for uid, info in zip(self.block_uids, self.block_infos):
  31. assert info is not None, f"Found no remote peers for block {uid}"
  32. assert self.spans_by_priority and self.spans_containing_block
  33. def update_(self):
  34. with self.lock_changes:
  35. self.update_block_infos_()
  36. self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
  37. def update_block_infos_(self):
  38. new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
  39. partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
  40. return_future=False)
  41. assert len(new_block_infos) == len(self.block_uids)
  42. for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
  43. if info is None:
  44. logger.warning(f"Found no block info for block {uid}")
  45. if not isinstance(info, RemoteModuleInfo):
  46. logger.warning(f"Unexpected dht entry type for {uid}: {info}")
  47. if not info.peer_ids:
  48. logger.warning(f"Found no active peers for block {uid}")
  49. if info.uid != uid:
  50. logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
  51. if not isinstance(info.peer_ids, set):
  52. logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
  53. self.block_infos[block_index] = info
  54. @staticmethod
  55. def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
  56. closed_spans = []
  57. active_spans = {}
  58. for block_index, info in enumerate(block_infos):
  59. for peer_id in info.peer_ids:
  60. if peer_id not in active_spans:
  61. active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
  62. else: # peer_id in active_spans
  63. active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
  64. for peer_id in list(active_spans.keys()):
  65. if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
  66. closed_spans.append(active_spans.pop(peer_id))
  67. assert not active_spans
  68. closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
  69. spans_containing_block = tuple(list() for _ in range(len(block_infos)))
  70. for span in closed_spans:
  71. for block_index in range(span.start, span.end):
  72. spans_containing_block[block_index].append(span)
  73. return closed_spans, spans_containing_block
  74. def __len__(self):
  75. return len(self.block_uids)