dht_handler.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import threading
  2. from functools import partial
  3. from typing import Dict, List, Optional, Sequence, Tuple, Union
  4. from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
  5. from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
  6. from hivemind.moe.expert_uid import (
  7. FLAT_EXPERT,
  8. UID_DELIMITER,
  9. UID_PATTERN,
  10. Coordinate,
  11. ExpertInfo,
  12. ExpertPrefix,
  13. ExpertUID,
  14. is_valid_uid,
  15. split_uid,
  16. )
  17. from hivemind.p2p import PeerID
  18. from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_time
  19. class DHTHandlerThread(threading.Thread):
  20. def __init__(
  21. self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
  22. ):
  23. super().__init__(**kwargs)
  24. if expiration is None:
  25. expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
  26. self.module_backends = module_backends
  27. self.dht = dht
  28. self.update_period = update_period
  29. self.expiration = expiration
  30. self.stop = threading.Event()
  31. def run(self) -> None:
  32. declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
  33. while not self.stop.wait(self.update_period):
  34. declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
  35. def declare_experts(
  36. dht: DHT, uids: Sequence[ExpertUID], expiration_time: DHTExpiration, wait: bool = True
  37. ) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
  38. """
  39. Make experts visible to all DHT peers; update timestamps if declared previously.
  40. :param uids: a list of expert ids to update
  41. :param wait: if True, awaits for declaration to finish, otherwise runs in background
  42. :param expiration_time: experts will be visible for this many seconds
  43. :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
  44. """
  45. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  46. if not isinstance(uids, list):
  47. uids = list(uids)
  48. for uid in uids:
  49. assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
  50. return dht.run_coroutine(
  51. partial(_declare_experts, uids=uids, expiration_time=expiration_time), return_future=not wait
  52. )
  53. async def _declare_experts(
  54. dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: DHTExpiration
  55. ) -> Dict[ExpertUID, bool]:
  56. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  57. data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
  58. peer_id_base58 = dht.peer_id.to_base58()
  59. for uid in uids:
  60. data_to_store[uid, None] = peer_id_base58
  61. prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
  62. for i in range(prefix.count(UID_DELIMITER) - 1):
  63. prefix, last_coord = split_uid(prefix)
  64. data_to_store[prefix, last_coord] = (uid, peer_id_base58)
  65. keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
  66. store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
  67. return store_ok
  68. def get_experts(
  69. dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
  70. ) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]:
  71. """
  72. :param uids: find experts with these ids from across the DHT
  73. :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
  74. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  75. :returns: a list of [RemoteExpert if found else None]
  76. """
  77. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  78. result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
  79. return create_remote_experts(result, dht, return_future)
  80. async def _get_experts(
  81. dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
  82. ) -> List[Optional[ExpertInfo]]:
  83. if expiration_time is None:
  84. expiration_time = get_dht_time()
  85. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  86. found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
  87. experts: List[Optional[ExpertInfo]] = [None] * len(uids)
  88. for i, uid in enumerate(uids):
  89. server_peer_id = found[uid]
  90. if server_peer_id is not None and isinstance(server_peer_id.value, str):
  91. experts[i] = ExpertInfo(uid, PeerID.from_base58(server_peer_id.value))
  92. return experts