dht_handler.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import threading
  2. from functools import partial
  3. from typing import Dict, List, Optional, Sequence, Tuple, Union
  4. from multiaddr import Multiaddr
  5. from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
  6. from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
  7. from hivemind.moe.server.expert_uid import (
  8. FLAT_EXPERT,
  9. UID_DELIMITER,
  10. UID_PATTERN,
  11. Coordinate,
  12. ExpertPrefix,
  13. ExpertUID,
  14. is_valid_uid,
  15. split_uid,
  16. )
  17. from hivemind.p2p import PeerID, PeerInfo
  18. from hivemind.utils import get_dht_time, LazyFutureCaller, LazyValue
  19. class DHTHandlerThread(threading.Thread):
  20. def __init__(self, experts, dht: DHT, peer_id: PeerID, update_period: int = 5, **kwargs):
  21. super().__init__(**kwargs)
  22. self.peer_id = peer_id
  23. self.experts = experts
  24. self.dht = dht
  25. self.update_period = update_period
  26. self.stop = threading.Event()
  27. def run(self) -> None:
  28. declare_experts(self.dht, self.experts.keys(), self.peer_id)
  29. while not self.stop.wait(self.update_period):
  30. declare_experts(self.dht, self.experts.keys(), self.peer_id)
  31. def declare_experts(
  32. dht: DHT, uids: Sequence[ExpertUID], peer_id: PeerID, expiration: DHTExpiration = 300, wait: bool = True
  33. ) -> Dict[ExpertUID, bool]:
  34. """
  35. Make experts visible to all DHT peers; update timestamps if declared previously.
  36. :param uids: a list of expert ids to update
  37. :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
  38. :param wait: if True, awaits for declaration to finish, otherwise runs in background
  39. :param expiration: experts will be visible for this many seconds
  40. :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
  41. """
  42. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  43. for uid in uids:
  44. assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
  45. addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
  46. return dht.run_coroutine(
  47. partial(_declare_experts, uids=list(uids), peer_id=peer_id, addrs=addrs, expiration=expiration),
  48. return_future=not wait,
  49. )
  50. async def _declare_experts(
  51. dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str], expiration: DHTExpiration
  52. ) -> Dict[ExpertUID, bool]:
  53. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  54. expiration_time = get_dht_time() + expiration
  55. data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
  56. for uid in uids:
  57. data_to_store[uid, None] = (peer_id.to_base58(), addrs)
  58. prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
  59. for i in range(prefix.count(UID_DELIMITER) - 1):
  60. prefix, last_coord = split_uid(prefix)
  61. data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
  62. keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
  63. store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
  64. return store_ok
  65. def get_experts(
  66. dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
  67. ) -> Union[List[Optional[RemoteExpert]], LazyFutureCaller[Optional[LazyValue[RemoteExpert]], Optional[RemoteExpert]]]:
  68. """
  69. :param uids: find experts with these ids from across the DHT
  70. :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
  71. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  72. :returns: a list of [RemoteExpert if found else None]
  73. """
  74. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  75. p2p = _RemoteModuleCall.run_coroutine(dht.replicate_p2p())
  76. result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
  77. def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
  78. return [val.get(p2p=p2p) if val is not None else None for val in vals]
  79. if return_future:
  80. return LazyFutureCaller(result, _unwrap_experts)
  81. return _unwrap_experts(result)
  82. async def _get_experts(
  83. dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
  84. ) -> List[Optional[LazyValue[RemoteExpert]]]:
  85. if expiration_time is None:
  86. expiration_time = get_dht_time()
  87. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  88. found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
  89. experts: List[Optional[RemoteExpert]] = [None] * len(uids)
  90. for i, uid in enumerate(uids):
  91. elem = found[uid]
  92. if elem is not None and isinstance(elem.value, tuple):
  93. peer_id, addrs = elem.value
  94. peer_info = PeerInfo(peer_id=PeerID.from_base58(peer_id), addrs=tuple(Multiaddr(a) for a in addrs))
  95. experts[i] = LazyValue(init=partial(RemoteExpert, uid=uid, server_peer_info=peer_info))
  96. return experts