# protocol.py
""" RPC protocol that provides nodes a way to communicate with each other """
from __future__ import annotations

import asyncio
from typing import Any, Collection, Coroutine, Dict, List, Optional, Sequence, Tuple, Union

from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
from hivemind.dht.routing import DHTID, BinaryDHTValue, RoutingTable, Subkey
from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
from hivemind.proto import dht_pb2
from hivemind.utils import MSGPackSerializer, get_logger
from hivemind.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapper
from hivemind.utils.timed_storage import (
    MAX_DHT_TIME_DISCREPANCY_SECONDS,
    DHTExpiration,
    ValueWithExpiration,
    get_dht_time,
)
  18. logger = get_logger(__name__)
  19. class DHTProtocol(ServicerBase):
  20. # fmt:off
  21. p2p: P2P
  22. node_id: DHTID; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
  23. storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
  24. record_validator: Optional[RecordValidatorBase]
  25. # fmt:on
  26. serializer = MSGPackSerializer # used to pack/unpack DHT Values for transfer over network
  27. RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b""
  28. @classmethod
  29. async def create(
  30. cls,
  31. p2p: P2P,
  32. node_id: DHTID,
  33. bucket_size: int,
  34. depth_modulo: int,
  35. num_replicas: int,
  36. wait_timeout: float,
  37. parallel_rpc: Optional[int] = None,
  38. cache_size: Optional[int] = None,
  39. client_mode: bool = False,
  40. record_validator: Optional[RecordValidatorBase] = None,
  41. authorizer: Optional[AuthorizerBase] = None,
  42. ) -> DHTProtocol:
  43. """
  44. A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
  45. As a side-effect, DHTProtocol also maintains a routing table as described in
  46. https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
  47. See DHTNode (node.py) for a more detailed description.
  48. :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
  49. for instance, def rpc_ping can be called as protocol.call_ping(peer_id, dht_id) from a remote machine
  50. Only the call_* methods are meant to be called publicly, e.g. from DHTNode
  51. Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
  52. """
  53. self = cls(_initialized_with_create=True)
  54. self.p2p = p2p
  55. self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
  56. self.wait_timeout = wait_timeout
  57. self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
  58. self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
  59. self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float("inf"))
  60. self.client_mode = client_mode
  61. self.record_validator = record_validator
  62. self.authorizer = authorizer
  63. if not client_mode:
  64. await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))
  65. self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
  66. else:
  67. # note: use empty node_info so peers won't add you to their routing tables
  68. self.node_info = dht_pb2.NodeInfo()
  69. return self
  70. def __init__(self, *, _initialized_with_create=False):
  71. """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
  72. assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
  73. super().__init__()
  74. def get_stub(self, peer: PeerID) -> AuthRPCWrapper:
  75. """get a stub that sends requests to a given peer"""
  76. stub = super().get_stub(self.p2p, peer)
  77. return AuthRPCWrapper(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
  78. async def call_ping(self, peer: PeerID, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
  79. """
  80. Get peer's node id and add him to the routing table. If peer doesn't respond, return None
  81. :param peer: peer ID to ping
  82. :param validate: if True, validates that node's peer_id is available
  83. :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
  84. :note: if DHTProtocol was created with client_mode=False, also request peer to add you to his routing table
  85. :return: node's DHTID, if peer responded and decided to send his node_id
  86. """
  87. try:
  88. async with self.rpc_semaphore:
  89. ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
  90. time_requested = get_dht_time()
  91. response = await self.get_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
  92. time_responded = get_dht_time()
  93. except Exception as e:
  94. logger.debug(f"DHTProtocol failed to ping {peer}", exc_info=True)
  95. response = None
  96. responded = bool(response and response.peer and response.peer.node_id)
  97. if responded and validate:
  98. try:
  99. if not self.client_mode and not response.available:
  100. raise ValidationError(
  101. f"Peer {peer} can't access this node. " f"Probably, libp2p has failed to bypass the firewall"
  102. )
  103. if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
  104. if (
  105. response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS
  106. or response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS
  107. ):
  108. raise ValidationError(
  109. f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
  110. f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})"
  111. )
  112. except ValidationError as e:
  113. if strict:
  114. raise
  115. else:
  116. logger.warning(repr(e))
  117. peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
  118. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
  119. return peer_id
  120. async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
  121. """Some node wants us to add it to our routing table."""
  122. response = dht_pb2.PingResponse(peer=self.node_info, dht_time=get_dht_time(), available=False)
  123. if request.peer and request.peer.node_id:
  124. sender_id = DHTID.from_bytes(request.peer.node_id)
  125. sender_peer_id = context.remote_id
  126. if request.validate:
  127. response.available = await self.call_ping(sender_peer_id, validate=False) == sender_id
  128. asyncio.create_task(
  129. self.update_routing_table(
  130. sender_id, sender_peer_id, responded=response.available or not request.validate
  131. )
  132. )
  133. return response
  134. async def call_store(
  135. self,
  136. peer: PeerID,
  137. keys: Sequence[DHTID],
  138. values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
  139. expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
  140. subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
  141. in_cache: Optional[Union[bool, Sequence[bool]]] = None,
  142. ) -> Optional[List[bool]]:
  143. """
  144. Ask a recipient to store several (key, value : expiration_time) items or update their older value
  145. :param peer: request this peer to store the data
  146. :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
  147. :param values: a list of N serialized values (bytes) for each respective key
  148. :param expiration_time: a list of N expiration timestamps for each respective key-value pair(see get_dht_time())
  149. :param subkeys: a list of N optional sub-keys. If None, stores value normally. If not subkey is not None:
  150. 1) if local storage doesn't have :key:, create a new dictionary {subkey: (value, expiration_time)}
  151. 2) if local storage already has a dictionary under :key:, try add (subkey, value, exp_time) to that dictionary
  152. 2) if local storage associates :key: with a normal value with smaller expiration, clear :key: and perform (1)
  153. 3) finally, if local storage currently associates :key: with a normal value with larger expiration, do nothing
  154. :param in_cache: a list of booleans, True = store i-th key in cache, value = store i-th key locally
  155. :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
  156. until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
  157. :return: list of [True / False] True = stored, False = failed (found newer value or no response)
  158. if peer did not respond (e.g. due to timeout or congestion), returns None
  159. """
  160. if isinstance(expiration_time, DHTExpiration):
  161. expiration_time = [expiration_time] * len(keys)
  162. if subkeys is None:
  163. subkeys = [None] * len(keys)
  164. in_cache = in_cache if in_cache is not None else [False] * len(keys) # default value (None)
  165. in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache # single bool
  166. keys, subkeys, values, expiration_time, in_cache = map(
  167. list, [keys, subkeys, values, expiration_time, in_cache]
  168. )
  169. for i in range(len(keys)):
  170. if subkeys[i] is None: # add default sub-key if not specified
  171. subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
  172. else:
  173. subkeys[i] = self.serializer.dumps(subkeys[i])
  174. if isinstance(values[i], DictionaryDHTValue):
  175. assert (
  176. subkeys[i] == self.IS_DICTIONARY
  177. ), "Please don't specify subkey when storing an entire dictionary"
  178. values[i] = self.serializer.dumps(values[i])
  179. assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
  180. store_request = dht_pb2.StoreRequest(
  181. keys=list(map(DHTID.to_bytes, keys)),
  182. subkeys=subkeys,
  183. values=values,
  184. expiration_time=expiration_time,
  185. in_cache=in_cache,
  186. peer=self.node_info,
  187. )
  188. try:
  189. async with self.rpc_semaphore:
  190. response = await self.get_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
  191. if response.peer and response.peer.node_id:
  192. peer_id = DHTID.from_bytes(response.peer.node_id)
  193. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
  194. return response.store_ok
  195. except Exception as e:
  196. logger.debug(f"DHTProtocol failed to store at {peer}", exc_info=True)
  197. asyncio.create_task(self.update_routing_table(self.routing_table.get(peer_id=peer), peer, responded=False))
  198. return None
  199. async def rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
  200. """Some node wants us to store this (key, value) pair"""
  201. if request.peer: # if requested, add peer to the routing table
  202. asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
  203. assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
  204. response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
  205. for key, tag, value_bytes, expiration_time, in_cache in zip(
  206. request.keys, request.subkeys, request.values, request.expiration_time, request.in_cache
  207. ):
  208. key_id = DHTID.from_bytes(key)
  209. storage = self.cache if in_cache else self.storage
  210. if tag == self.IS_DICTIONARY: # store an entire dictionary with several subkeys
  211. value_dictionary = self.serializer.loads(value_bytes)
  212. assert isinstance(value_dictionary, DictionaryDHTValue)
  213. if not self._validate_dictionary(key, value_dictionary):
  214. response.store_ok.append(False)
  215. continue
  216. response.store_ok.append(
  217. all(
  218. storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
  219. for subkey, item in value_dictionary.items()
  220. )
  221. )
  222. continue
  223. if not self._validate_record(key, tag, value_bytes, expiration_time):
  224. response.store_ok.append(False)
  225. continue
  226. if tag == self.IS_REGULAR_VALUE: # store normal value without subkeys
  227. response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
  228. else: # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
  229. subkey = self.serializer.loads(tag)
  230. response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
  231. return response
  232. async def call_find(
  233. self, peer: PeerID, keys: Collection[DHTID]
  234. ) -> Optional[
  235. Dict[
  236. DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, PeerID]]
  237. ]
  238. ]:
  239. """
  240. Request keys from a peer. For each key, look for its (value, expiration time) locally and
  241. k additional peers that are most likely to have this key (ranked by XOR distance)
  242. :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
  243. value: value stored by the recipient with that key, or None if peer doesn't have this value
  244. expiration time: expiration time of the returned value, None if no value was found
  245. neighbors: a dictionary[node_id : peer_id] containing nearest neighbors from peer's routing table
  246. If peer didn't respond, returns None
  247. """
  248. keys = list(keys)
  249. find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
  250. try:
  251. async with self.rpc_semaphore:
  252. response = await self.get_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
  253. if response.peer and response.peer.node_id:
  254. peer_id = DHTID.from_bytes(response.peer.node_id)
  255. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
  256. assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"
  257. output = {} # unpack data depending on its type
  258. for key, result in zip(keys, response.results):
  259. key_bytes = DHTID.to_bytes(key)
  260. nearest = dict(
  261. zip(
  262. map(DHTID.from_bytes, result.nearest_node_ids),
  263. map(PeerID, result.nearest_peer_ids),
  264. )
  265. )
  266. if result.type == dht_pb2.NOT_FOUND:
  267. output[key] = None, nearest
  268. elif result.type == dht_pb2.FOUND_REGULAR:
  269. if not self._validate_record(
  270. key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time
  271. ):
  272. output[key] = None, nearest
  273. continue
  274. output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
  275. elif result.type == dht_pb2.FOUND_DICTIONARY:
  276. value_dictionary = self.serializer.loads(result.value)
  277. if not self._validate_dictionary(key_bytes, value_dictionary):
  278. output[key] = None, nearest
  279. continue
  280. output[key] = ValueWithExpiration(value_dictionary, result.expiration_time), nearest
  281. else:
  282. logger.error(f"Unknown result type: {result.type}")
  283. return output
  284. except Exception as e:
  285. logger.debug(f"DHTProtocol failed to find at {peer}", exc_info=True)
  286. asyncio.create_task(self.update_routing_table(self.routing_table.get(peer_id=peer), peer, responded=False))
  287. async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
  288. """
  289. Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
  290. Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
  291. """
  292. if request.peer: # if requested, add peer to the routing table
  293. asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
  294. response = dht_pb2.FindResponse(results=[], peer=self.node_info)
  295. for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
  296. maybe_item = self.storage.get(key_id)
  297. cached_item = self.cache.get(key_id)
  298. if cached_item is not None and (
  299. maybe_item is None or cached_item.expiration_time > maybe_item.expiration_time
  300. ):
  301. maybe_item = cached_item
  302. if maybe_item is None: # value not found
  303. item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
  304. elif isinstance(maybe_item.value, DictionaryDHTValue):
  305. item = dht_pb2.FindResult(
  306. type=dht_pb2.FOUND_DICTIONARY,
  307. value=self.serializer.dumps(maybe_item.value),
  308. expiration_time=maybe_item.expiration_time,
  309. )
  310. else: # found regular value
  311. item = dht_pb2.FindResult(
  312. type=dht_pb2.FOUND_REGULAR, value=maybe_item.value, expiration_time=maybe_item.expiration_time
  313. )
  314. for node_id, peer_id in self.routing_table.get_nearest_neighbors(
  315. key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)
  316. ):
  317. item.nearest_node_ids.append(node_id.to_bytes())
  318. item.nearest_peer_ids.append(peer_id.to_bytes())
  319. response.results.append(item)
  320. return response
  321. async def update_routing_table(self, node_id: Optional[DHTID], peer_id: PeerID, responded=True):
  322. """
  323. This method is called on every incoming AND outgoing request to update the routing table
  324. :param peer_id: sender peer_id for incoming requests, recipient peer_id for outgoing requests
  325. :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
  326. :param responded: for outgoing requests, this indicated whether recipient responded or not.
  327. For incoming requests, this should always be True
  328. """
  329. node_id = node_id if node_id is not None else self.routing_table.get(peer_id=peer_id)
  330. if responded: # incoming request or outgoing request with response
  331. if node_id not in self.routing_table:
  332. # we just met a new node, maybe we know some values that it *should* store
  333. data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
  334. for key, item in list(self.storage.items()):
  335. neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
  336. if neighbors:
  337. nearest_distance = neighbors[0][0].xor_distance(key)
  338. farthest_distance = neighbors[-1][0].xor_distance(key)
  339. new_node_should_store = node_id.xor_distance(key) < farthest_distance
  340. this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
  341. if not neighbors or (new_node_should_store and this_node_is_responsible):
  342. data_to_send.append((key, item.value, item.expiration_time))
  343. if data_to_send:
  344. asyncio.create_task(self.call_store(peer_id, *zip(*data_to_send), in_cache=False))
  345. maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_id)
  346. if maybe_node_to_ping is not None:
  347. # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
  348. # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
  349. asyncio.create_task(self.call_ping(maybe_node_to_ping[1])) # [1]-th element is that node's peer_id
  350. else: # we sent outgoing request and peer did not respond
  351. if node_id is not None and node_id in self.routing_table:
  352. del self.routing_table[node_id]
  353. def _validate_record(
  354. self, key_bytes: bytes, subkey_bytes: bytes, value_bytes: bytes, expiration_time: float
  355. ) -> bool:
  356. if self.record_validator is None:
  357. return True
  358. record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
  359. return self.record_validator.validate(record)
  360. def _validate_dictionary(self, key_bytes: bytes, dictionary: DictionaryDHTValue) -> bool:
  361. if self.record_validator is None:
  362. return True
  363. with dictionary.freeze():
  364. for subkey, (value_bytes, expiration_time) in dictionary.items():
  365. subkey_bytes = self.serializer.dumps(subkey)
  366. record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
  367. if not self.record_validator.validate(record):
  368. return False
  369. return True
  370. class ValidationError(Exception):
  371. """This exception is thrown if DHT node didn't pass validation by other nodes."""