dht_handler.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import threading
  2. from functools import partial
  3. from typing import Dict, List, Optional, Sequence, Tuple, Union
  4. from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
  5. from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
  6. from hivemind.moe.server.expert_uid import (
  7. FLAT_EXPERT,
  8. UID_DELIMITER,
  9. UID_PATTERN,
  10. Coordinate,
  11. ExpertPrefix,
  12. ExpertUID,
  13. is_valid_uid,
  14. split_uid,
  15. )
  16. from hivemind.p2p import PeerID, PeerInfo
  17. from hivemind.utils import MPFuture, get_dht_time
  18. class DHTHandlerThread(threading.Thread):
  19. def __init__(self, experts, dht: DHT, peer_id: PeerID, update_period: int = 5, **kwargs):
  20. super().__init__(**kwargs)
  21. self.peer_id = peer_id
  22. self.experts = experts
  23. self.dht = dht
  24. self.update_period = update_period
  25. self.stop = threading.Event()
  26. def run(self) -> None:
  27. declare_experts(self.dht, self.experts.keys(), self.peer_id)
  28. while not self.stop.wait(self.update_period):
  29. declare_experts(self.dht, self.experts.keys(), self.peer_id)
  30. def declare_experts(
  31. dht: DHT, uids: Sequence[ExpertUID], peer_id: PeerID, expiration: DHTExpiration = 300, wait: bool = True
  32. ) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
  33. """
  34. Make experts visible to all DHT peers; update timestamps if declared previously.
  35. :param uids: a list of expert ids to update
  36. :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
  37. :param wait: if True, awaits for declaration to finish, otherwise runs in background
  38. :param expiration: experts will be visible for this many seconds
  39. :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
  40. """
  41. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  42. for uid in uids:
  43. assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
  44. addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
  45. return dht.run_coroutine(
  46. partial(_declare_experts, uids=list(uids), peer_id=peer_id, addrs=addrs, expiration=expiration),
  47. return_future=not wait,
  48. )
  49. async def _declare_experts(
  50. dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str], expiration: DHTExpiration
  51. ) -> Dict[ExpertUID, bool]:
  52. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  53. expiration_time = get_dht_time() + expiration
  54. data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
  55. for uid in uids:
  56. data_to_store[uid, None] = (peer_id.to_base58(), addrs)
  57. prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
  58. for i in range(prefix.count(UID_DELIMITER) - 1):
  59. prefix, last_coord = split_uid(prefix)
  60. data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
  61. keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
  62. store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
  63. return store_ok
  64. def get_experts(
  65. dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
  66. ) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]:
  67. """
  68. :param uids: find experts with these ids from across the DHT
  69. :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
  70. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  71. :returns: a list of [RemoteExpert if found else None]
  72. """
  73. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  74. result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
  75. return create_remote_experts(result, dht, return_future)
  76. async def _get_experts(
  77. dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
  78. ) -> List[Optional[RemoteExpertInfo]]:
  79. if expiration_time is None:
  80. expiration_time = get_dht_time()
  81. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  82. found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
  83. experts: List[Optional[RemoteExpert]] = [None] * len(uids)
  84. for i, uid in enumerate(uids):
  85. elem = found[uid]
  86. if elem is not None and isinstance(elem.value, tuple):
  87. experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(elem.value))
  88. return experts