# protocol.py
""" RPC protocol that provides nodes a way to communicate with each other. Based on gRPC.AIO. """
from __future__ import annotations

import asyncio
from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection

import grpc

from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS

logger = get_logger(__name__)
  12. class DHTProtocol(dht_grpc.DHTServicer):
  13. # fmt:off
  14. node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
  15. channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
  16. storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
  17. # fmt:on
  18. serializer = MSGPackSerializer # used to pack/unpack DHT Values for transfer over network
  19. RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''
  20. @classmethod
  21. async def create(
  22. cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
  23. parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
  24. listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
  25. channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
  26. """
  27. A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
  28. As a side-effect, DHTProtocol also maintains a routing table as described in
  29. https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
  30. See DHTNode (node.py) for a more detailed description.
  31. :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
  32. for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
  33. Only the call_* methods are meant to be called publicly, e.g. from DHTNode
  34. Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
  35. """
  36. self = cls(_initialized_with_create=True)
  37. self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
  38. self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
  39. self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
  40. self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
  41. self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
  42. if listen: # set up server to process incoming rpc requests
  43. grpc.aio.init_grpc_aio()
  44. self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
  45. dht_grpc.add_DHTServicer_to_server(self, self.server)
  46. self.port = self.server.add_insecure_port(listen_on)
  47. assert self.port != 0, f"Failed to listen to {listen_on}"
  48. if endpoint is not None and endpoint.endswith('*'):
  49. endpoint = replace_port(endpoint, self.port)
  50. self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
  51. endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
  52. await self.server.start()
  53. else: # not listening to incoming requests, client-only mode
  54. # note: use empty node_info so peers won't add you to their routing tables
  55. self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
  56. if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
  57. logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on"
  58. f"and kwargs have no effect (unused kwargs: {kwargs})")
  59. return self
  60. def __init__(self, *, _initialized_with_create=False):
  61. """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
  62. assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
  63. super().__init__()
  64. async def shutdown(self, timeout=None):
  65. """ Process existing requests, close all connections and stop the server """
  66. if self.server:
  67. await self.server.stop(timeout)
  68. else:
  69. logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
  70. def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
  71. """ get a DHTStub that sends requests to a given peer """
  72. return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
  73. async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
  74. """
  75. Get peer's node id and add him to the routing table. If peer doesn't respond, return None
  76. :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
  77. :param validate: if True, validates that node's endpoint is available
  78. :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
  79. :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
  80. :return: node's DHTID, if peer responded and decided to send his node_id
  81. """
  82. try:
  83. async with self.rpc_semaphore:
  84. ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
  85. time_requested = get_dht_time()
  86. response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
  87. time_responded = get_dht_time()
  88. except grpc.aio.AioRpcError as error:
  89. logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
  90. response = None
  91. responded = bool(response and response.peer and response.peer.node_id)
  92. if responded and validate:
  93. try:
  94. if self.server is not None and not response.available:
  95. raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
  96. f"Make sure that this port is open for incoming requests.")
  97. if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
  98. if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
  99. response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
  100. raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
  101. f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
  102. except ValidationError as e:
  103. if strict:
  104. raise
  105. else:
  106. logger.warning(repr(e))
  107. peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
  108. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
  109. return peer_id
  110. async def get_outgoing_request_endpoint(self, peer: Endpoint) -> Optional[Endpoint]:
  111. """ ask this peer how it perceives this node's outgoing request address """
  112. try:
  113. async with self.rpc_semaphore:
  114. ping_request = dht_pb2.PingRequest(peer=None, validate=False)
  115. response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
  116. if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
  117. return response.sender_endpoint
  118. except grpc.aio.AioRpcError as error:
  119. logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
  120. async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
  121. """ Some node wants us to add it to our routing table. """
  122. response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
  123. dht_time=get_dht_time(), available=False)
  124. if request.peer and request.peer.node_id and request.peer.rpc_port:
  125. sender_id = DHTID.from_bytes(request.peer.node_id)
  126. if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
  127. sender_endpoint = request.peer.endpoint # if peer has preferred endpoint, use it
  128. else:
  129. sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)
  130. response.sender_endpoint = sender_endpoint
  131. if request.validate:
  132. response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id
  133. asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
  134. responded=response.available or not request.validate))
  135. return response
  136. async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
  137. values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
  138. expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
  139. subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
  140. in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Optional[List[bool]]:
  141. """
  142. Ask a recipient to store several (key, value : expiration_time) items or update their older value
  143. :param peer: request this peer to store the data
  144. :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
  145. :param values: a list of N serialized values (bytes) for each respective key
  146. :param expiration_time: a list of N expiration timestamps for each respective key-value pair(see get_dht_time())
  147. :param subkeys: a list of N optional sub-keys. If None, stores value normally. If not subkey is not None:
  148. 1) if local storage doesn't have :key:, create a new dictionary {subkey: (value, expiration_time)}
  149. 2) if local storage already has a dictionary under :key:, try add (subkey, value, exp_time) to that dictionary
  150. 2) if local storage associates :key: with a normal value with smaller expiration, clear :key: and perform (1)
  151. 3) finally, if local storage currently associates :key: with a normal value with larger expiration, do nothing
  152. :param in_cache: a list of booleans, True = store i-th key in cache, value = store i-th key locally
  153. :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
  154. until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
  155. :return: list of [True / False] True = stored, False = failed (found newer value or no response)
  156. if peer did not respond (e.g. due to timeout or congestion), returns None
  157. """
  158. if isinstance(expiration_time, DHTExpiration):
  159. expiration_time = [expiration_time] * len(keys)
  160. if subkeys is None:
  161. subkeys = [None] * len(keys)
  162. in_cache = in_cache if in_cache is not None else [False] * len(keys) # default value (None)
  163. in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache # single bool
  164. keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
  165. for i in range(len(keys)):
  166. if subkeys[i] is None: # add default sub-key if not specified
  167. subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
  168. else:
  169. subkeys[i] = self.serializer.dumps(subkeys[i])
  170. if isinstance(values[i], DictionaryDHTValue):
  171. assert subkeys[i] == self.IS_DICTIONARY, "Please don't specify subkey when storing an entire dictionary"
  172. values[i] = self.serializer.dumps(values[i])
  173. assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
  174. store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), subkeys=subkeys, values=values,
  175. expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
  176. try:
  177. async with self.rpc_semaphore:
  178. response = await self._get_dht_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
  179. if response.peer and response.peer.node_id:
  180. peer_id = DHTID.from_bytes(response.peer.node_id)
  181. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
  182. return response.store_ok
  183. except grpc.aio.AioRpcError as error:
  184. logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
  185. asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
  186. return None
  187. async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
  188. """ Some node wants us to store this (key, value) pair """
  189. if request.peer: # if requested, add peer to the routing table
  190. asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
  191. assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
  192. response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
  193. keys = map(DHTID.from_bytes, request.keys)
  194. for key_id, tag, value_bytes, expiration_time, in_cache in zip(
  195. keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
  196. storage = self.cache if in_cache else self.storage
  197. if tag == self.IS_REGULAR_VALUE: # store normal value without subkeys
  198. response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
  199. elif tag == self.IS_DICTIONARY: # store an entire dictionary with several subkeys
  200. value_dictionary = self.serializer.loads(value_bytes)
  201. assert isinstance(value_dictionary, DictionaryDHTValue)
  202. response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
  203. for subkey, item in value_dictionary.items()))
  204. else: # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
  205. subkey = self.serializer.loads(tag)
  206. response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
  207. return response
  208. async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[Dict[
  209. DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
  210. """
  211. Request keys from a peer. For each key, look for its (value, expiration time) locally and
  212. k additional peers that are most likely to have this key (ranked by XOR distance)
  213. :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
  214. value: value stored by the recipient with that key, or None if peer doesn't have this value
  215. expiration time: expiration time of the returned value, None if no value was found
  216. neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
  217. If peer didn't respond, returns None
  218. """
  219. keys = list(keys)
  220. find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
  221. try:
  222. async with self.rpc_semaphore:
  223. response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
  224. if response.peer and response.peer.node_id:
  225. peer_id = DHTID.from_bytes(response.peer.node_id)
  226. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
  227. assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"
  228. output = {} # unpack data depending on its type
  229. for key, result in zip(keys, response.results):
  230. nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
  231. if result.type == dht_pb2.NOT_FOUND:
  232. output[key] = None, nearest
  233. elif result.type == dht_pb2.FOUND_REGULAR:
  234. output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
  235. elif result.type == dht_pb2.FOUND_DICTIONARY:
  236. deserialized_dictionary = self.serializer.loads(result.value)
  237. output[key] = ValueWithExpiration(deserialized_dictionary, result.expiration_time), nearest
  238. else:
  239. logger.error(f"Unknown result type: {result.type}")
  240. return output
  241. except grpc.aio.AioRpcError as error:
  242. logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
  243. asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
  244. async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
  245. """
  246. Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
  247. Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
  248. """
  249. if request.peer: # if requested, add peer to the routing table
  250. asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
  251. response = dht_pb2.FindResponse(results=[], peer=self.node_info)
  252. for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
  253. maybe_item = self.storage.get(key_id)
  254. cached_item = self.cache.get(key_id)
  255. if cached_item is not None and (maybe_item is None
  256. or cached_item.expiration_time > maybe_item.expiration_time):
  257. maybe_item = cached_item
  258. if maybe_item is None: # value not found
  259. item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
  260. elif isinstance(maybe_item.value, DictionaryDHTValue):
  261. item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_item.value),
  262. expiration_time=maybe_item.expiration_time)
  263. else: # found regular value
  264. item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
  265. expiration_time=maybe_item.expiration_time)
  266. for node_id, endpoint in self.routing_table.get_nearest_neighbors(
  267. key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
  268. item.nearest_node_ids.append(node_id.to_bytes())
  269. item.nearest_endpoints.append(endpoint)
  270. response.results.append(item)
  271. return response
  272. async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
  273. """
  274. This method is called on every incoming AND outgoing request to update the routing table
  275. :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
  276. :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
  277. :param responded: for outgoing requests, this indicated whether recipient responded or not.
  278. For incoming requests, this should always be True
  279. """
  280. node_id = node_id if node_id is not None else self.routing_table.get(endpoint=peer_endpoint)
  281. if responded: # incoming request or outgoing request with response
  282. if node_id not in self.routing_table:
  283. # we just met a new node, maybe we know some values that it *should* store
  284. data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
  285. for key, item in list(self.storage.items()):
  286. neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
  287. if neighbors:
  288. nearest_distance = neighbors[0][0].xor_distance(key)
  289. farthest_distance = neighbors[-1][0].xor_distance(key)
  290. new_node_should_store = node_id.xor_distance(key) < farthest_distance
  291. this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
  292. if not neighbors or (new_node_should_store and this_node_is_responsible):
  293. data_to_send.append((key, item.value, item.expiration_time))
  294. if data_to_send:
  295. asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
  296. maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_endpoint)
  297. if maybe_node_to_ping is not None:
  298. # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
  299. # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
  300. asyncio.create_task(self.call_ping(maybe_node_to_ping[1])) # [1]-th element is that node's endpoint
  301. else: # we sent outgoing request and peer did not respond
  302. if node_id is not None and node_id in self.routing_table:
  303. del self.routing_table[node_id]
class ValidationError(Exception):
    """ This exception is thrown if DHT node didn't pass validation by other nodes. """