  1. """
  2. This is a Distributed Hash Table optimized for rapidly accessing a lot of lightweight metadata.
  3. Hivemind DHT is based on Kademlia [1] with added support for improved bulk store/get operations and caching.
  4. The code is organized as follows:
  5. * **class DHT (__init__.py)** - high-level class for model training. Runs DHTNode in a background process.
  6. * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys.
  7. * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes.
  8. * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers.
  9. - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
  10. - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
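
Example: a minimal quickstart sketch. The first peer starts a new DHT; the second joins it using
the first peer's visible multiaddrs (the actual addresses depend on your machine):

    from hivemind import DHT

    first_peer = DHT(start=True)  # start a DHT node in a background process
    second_peer = DHT(initial_peers=first_peer.get_visible_maddrs(), start=True)  # join the first peer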
  11. """
  12. from __future__ import annotations
  13. import asyncio
  14. import multiprocessing as mp
  15. import os
  16. from concurrent.futures import ThreadPoolExecutor
  17. from functools import partial
  18. from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
  19. from multiaddr import Multiaddr
  20. from hivemind.dht.node import DHTID, DHTNode
  21. from hivemind.dht.routing import DHTKey, DHTValue, Subkey
  22. from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
  23. from hivemind.p2p import P2P
  24. from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
  25. logger = get_logger(__name__)
  26. ReturnType = TypeVar("ReturnType")
  27. class DHT(mp.Process):
  28. """
  29. A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
  30. * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
  31. * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
  32. :param p2p: instance of hivemind.p2p.P2P that will be used for communication.
  33. If None, DHTNode will create and manage its own P2P instance with given initial_peers and
  34. parameters from ``kwargs``
  35. :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
  36. :param start: if True, automatically starts the background process on creation. Otherwise await manual start
  37. :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
  38. :param max_workers: declare_experts and get_experts will use up to this many parallel workers
  39. (but no more than one per key)
  40. :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
  41. :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
  42. The validators will be combined using the CompositeValidator class. It merges them when possible
  43. (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
  44. :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
  45. :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
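
    Example: a lifecycle sketch with manual start (uses only methods defined in this class; a standalone
    peer with no initial_peers):

        dht = DHT(start=False, daemon=True)
        dht.run_in_background(await_ready=True)  # start the process, block until the node is ready
        ...  # use dht.get / dht.store / dht.run_coroutine here
        dht.shutdown()  # request graceful shutdown, terminate after shutdown_timeout if unresponsive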
  46. """
  47. _node: DHTNode
  48. def __init__(
  49. self,
  50. p2p: Optional[P2P] = None,
  51. initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
  52. *,
  53. start: bool,
  54. daemon: bool = True,
  55. max_workers: Optional[int] = None,
  56. record_validators: Iterable[RecordValidatorBase] = (),
  57. shutdown_timeout: float = 3,
  58. **kwargs,
  59. ):
  60. super().__init__()
  61. self.p2p = p2p
  62. if not (
  63. initial_peers is None
  64. or (
  65. isinstance(initial_peers, Sequence)
  66. and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
  67. )
  68. ):
  69. raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
  70. self.initial_peers = initial_peers
  71. self.kwargs = kwargs
  72. self.max_workers = max_workers
  73. self._record_validator = CompositeValidator(record_validators)
  74. self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
  75. self.shutdown_timeout = shutdown_timeout
  76. self.ready = mp.Event()
  77. self.daemon = daemon
  78. if start:
  79. self.run_in_background(await_ready=True)

    def run(self) -> None:
        """Serve DHT forever. This function will not return until DHT node is shut down"""
        loop = switch_to_uvloop()

        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:

            async def _run():
                self._node = await DHTNode.create(
                    p2p=self.p2p,
                    initial_peers=self.initial_peers,
                    num_workers=self.max_workers or 1,
                    record_validator=self._record_validator,
                    **self.kwargs,
                )
                self.ready.set()

                while True:
                    # dispatch loop: receive (method_name, args, kwargs) from the host process via the pipe
                    # and run the corresponding private coroutine (_get, _store, _run_coroutine, _shutdown)
                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                    if method == "_shutdown":
                        await task
                        break

            coro = _run()
            loop.run_until_complete(coro)

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts DHT in a background process. If await_ready, this method will wait until the background DHT
        is ready to process incoming requests, for at most :timeout: seconds.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")

    def shutdown(self) -> None:
        """Shut down a running dht process"""
        if self.is_alive():
            self._outer_pipe.send(("_shutdown", [], {}))
            self.join(self.shutdown_timeout)
            if self.is_alive():
                logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
                self.terminate()

    async def _shutdown(self):
        await self._node.shutdown()

    def get(
        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
  122. """
  123. Search for a key across DHT and return either first or latest entry (if found).
  124. :param key: same key as in node.store(...)
  125. :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
  126. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  127. :param kwargs: parameters forwarded to DHTNode.get_many_by_id
  128. :returns: (value, expiration time); if value was not found, returns None
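
        Example: a usage sketch ("my_key" is hypothetical and assumed to have been stored earlier):

            result = dht.get("my_key", latest=True)
            if result is not None:
                value, expiration_time = result  # ValueWithExpiration unpacks into (value, expiration_time)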
  129. """
        future = MPFuture()
        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
        return future if return_future else future.result()

    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
        try:
            result = await self._node.get(key, latest=latest, **kwargs)
            if not future.done():
                future.set_result(result)
        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise

    def store(
        self,
        key: DHTKey,
        value: DHTValue,
        expiration_time: DHTExpiration,
        subkey: Optional[Subkey] = None,
        return_future: bool = False,
        **kwargs,
    ) -> Union[bool, MPFuture]:
  151. """
  152. Find num_replicas best nodes to store (key, value) and store it there until expiration time.
  153. :param key: msgpack-serializable key to be associated with value until expiration.
  154. :param value: msgpack-serializable value to be stored under a given key until expiration.
  155. :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
  156. :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
  157. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  158. :returns: True if store succeeds, False if it fails (due to no response or newer value)
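
        Example: a usage sketch (key and value are hypothetical; get_dht_time is the hivemind clock
        referenced in the expiration_time description above):

            import hivemind

            success = dht.store("my_key", value=42, expiration_time=hivemind.get_dht_time() + 600)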
  159. """
        future = MPFuture()
        self._outer_pipe.send(
            (
                "_store",
                [],
                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
            )
        )
        return future if return_future else future.result()

    async def _store(
        self,
        key: DHTKey,
        value: DHTValue,
        expiration_time: DHTExpiration,
        subkey: Optional[Subkey],
        future: MPFuture,
        **kwargs,
    ):
        try:
            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
            if not future.done():
                future.set_result(result)
        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise

    def run_coroutine(
        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
    ) -> Union[ReturnType, MPFuture[ReturnType]]:
  189. """
  190. Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
  191. for running custom functions DHT for special cases (e.g. declare experts, beam search)
  192. :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
  193. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  194. :returns: coroutine outputs or MPFuture for these outputs
  195. :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
  196. DHT fields made by this coroutine will not be accessible from the host process.
  197. :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
  198. or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
  199. :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
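
        Example: a minimal sketch of a custom coroutine (the function below is hypothetical; note that coro
        is sent over a pipe, so it must be picklable, e.g. defined at module level):

            async def _wait_and_report(dht: DHT, node: DHTNode) -> str:
                await asyncio.sleep(0.1)  # time-consuming work must be awaited so the node keeps serving
                return "done"

            assert dht.run_coroutine(_wait_and_report) == "done"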
  200. """
        future = MPFuture()
        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
        return future if return_future else future.result()

    async def _run_coroutine(
        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
    ):
        # run the user coroutine alongside a watcher task that completes if the host process cancels the future
        main_task = asyncio.create_task(coro(self, self._node))
        cancel_task = asyncio.create_task(await_cancelled(future))
        try:
            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
            if future.cancelled():
                main_task.cancel()
            else:
                future.set_result(await main_task)
        except BaseException as e:
            logger.exception(f"Caught an exception when running a coroutine: {e}")
            if not future.done():
                future.set_exception(e)

    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
        if not self.ready.is_set():
            raise RuntimeError(
                "Can't append new validators before the DHT process has started. "
                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
            )

        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))

    async def _add_validators(self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
        node.protocol.record_validator.extend(record_validators)

    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
        """
        Get multiaddrs of the current DHT node that should be accessible by other peers.

        :param latest: ask the P2P daemon to refresh the visible multiaddrs
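
        Example: a sketch of sharing this peer's addresses so that other peers can join (printing is illustrative):

            for maddr in dht.get_visible_maddrs():
                print(maddr)  # any of these can be passed as initial_peers to a new DHT instance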
  232. """
  233. return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
  234. async def _get_visible_maddrs(self, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
  235. return await node.get_visible_maddrs(latest=latest)

    def __del__(self):
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()