# protocol.py (~18 KB)
  1. """ RPC protocol that provides nodes a way to communicate with each other. Based on gRPC.AIO. """
  2. from __future__ import annotations
  3. import asyncio
  4. import heapq
  5. from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
  6. from warnings import warn
  7. import grpc
  8. import grpc.experimental.aio
  9. from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
  10. from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
  11. from hivemind.utils import Endpoint, get_logger, replace_port
  12. logger = get_logger(__name__)
class DHTProtocol(dht_grpc.DHTServicer):
    """
    gRPC servicer implementing Kademlia-style RPCs (rpc_ping / rpc_store / rpc_find) plus their
    outgoing counterparts (call_*), while maintaining a routing table of known peers.
    """
    # Attributes below are assigned in DHTProtocol.create (the only sanctioned constructor), not in __init__.
    # fmt:off
    node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
    # fmt:on
  19. @classmethod
  20. async def create(
  21. cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
  22. parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
  23. channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
  24. """
  25. A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
  26. As a side-effect, DHTProtocol also maintains a routing table as described in
  27. https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
  28. See DHTNode (node.py) for a more detailed description.
  29. :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
  30. for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
  31. Only the call_* methods are meant to be called publicly, e.g. from DHTNode
  32. Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
  33. """
  34. self = cls(_initialized_with_create=True)
  35. self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
  36. self.wait_timeout, self.channel_options = wait_timeout, channel_options
  37. self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
  38. self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
  39. self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
  40. if listen: # set up server to process incoming rpc requests
  41. grpc.experimental.aio.init_grpc_aio()
  42. self.server = grpc.experimental.aio.server(**kwargs)
  43. dht_grpc.add_DHTServicer_to_server(self, self.server)
  44. found_port = self.server.add_insecure_port(listen_on)
  45. assert found_port != 0, f"Failed to listen to {listen_on}"
  46. self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=found_port)
  47. self.port = found_port
  48. await self.server.start()
  49. else: # not listening to incoming requests, client-only mode
  50. # note: use empty node_info so peers wont add you to their routing tables
  51. self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
  52. if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
  53. warn(f"DHTProtocol has no server (due to listen=False), listen_on"
  54. f"and kwargs have no effect (unused kwargs: {kwargs})")
  55. return self
  56. def __init__(self, *, _initialized_with_create=False):
  57. """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
  58. assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
  59. super().__init__()
  60. async def shutdown(self, timeout=None):
  61. """ Process existing requests, close all connections and stop the server """
  62. if self.server:
  63. await self.server.stop(timeout)
  64. else:
  65. warn("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
  66. def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
  67. """ get a DHTStub that sends requests to a given peer """
  68. channel = grpc.experimental.aio.insecure_channel(peer, options=self.channel_options)
  69. return dht_grpc.DHTStub(channel)
  70. async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
  71. """
  72. Get peer's node id and add him to the routing table. If peer doesn't respond, return None
  73. :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
  74. :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
  75. :return: node's DHTID, if peer responded and decided to send his node_id
  76. """
  77. try:
  78. async with self.rpc_semaphore:
  79. peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
  80. except grpc.experimental.aio.AioRpcError as error:
  81. logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
  82. peer_info = None
  83. responded = bool(peer_info and peer_info.node_id)
  84. peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
  85. asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
  86. return peer_id
  87. async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
  88. """ Some node wants us to add it to our routing table. """
  89. if peer_info.node_id and peer_info.rpc_port:
  90. sender_id = DHTID.from_bytes(peer_info.node_id)
  91. rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
  92. asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
  93. return self.node_info
    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
                         expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
        """
        Ask a recipient to store several (key, value : expiration_time) items or update their older value
        :param peer: request this peer to store the data
        :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
        :param values: a list of N serialized values (bytes) for each respective key
        :param expiration_time: a list of N expiration timestamps for each respective key-value pair
            (see get_dht_time()); a single scalar is broadcast to all keys
        :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key locally
        :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
            until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
        :return: list of [True / False], True = stored, False = failed (found newer value);
            if the peer did not respond (e.g. due to timeout or congestion), returns [False] * len(keys)
        """
        # broadcast scalar expiration / in_cache arguments so that all four sequences are aligned per key
        if isinstance(expiration_time, DHTExpiration):
            expiration_time = [expiration_time] * len(keys)
        in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
        in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
        keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
        assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
                                             expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
            # a non-empty responder id means the peer is listening and may be added to our routing table
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
            return response.store_ok
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
            # mark the peer as unresponsive (may evict it from the routing table) and report total failure
            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
            return [False] * len(keys)
  128. async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
  129. """ Some node wants us to store this (key, value) pair """
  130. if request.peer: # if requested, add peer to the routing table
  131. asyncio.create_task(self.rpc_ping(request.peer, context))
  132. assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
  133. response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
  134. for key_bytes, value_bytes, expiration_time, in_cache in zip(
  135. request.keys, request.values, request.expiration_time, request.in_cache):
  136. local_memory = self.cache if in_cache else self.storage
  137. response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
  138. return response
    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
        """
        Request keys from a peer. For each key, look for its (value, expiration time) locally and
        k additional peers that are most likely to have this key (ranked by XOR distance)

        :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
            value: value stored by the recipient with that key, or None if peer doesn't have this value
            expiration time: expiration time of the returned value, None if no value was found
            neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
            If the peer didn't respond, returns None
        """
        keys = list(keys)
        find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
            # a non-empty responder id means the peer is listening and may be added to our routing table
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
                "DHTProtocol: response is not aligned with keys and/or expiration times"
            output = {}  # unpack data without special NOT_FOUND_* values
            for key, value, expiration_time, nearest in zip(
                    keys, response.values, response.expiration_time, response.nearest):
                # translate wire-level "not found" sentinels back into None
                value = value if value != _NOT_FOUND_VALUE else None
                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
                nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
                output[key] = (value, expiration_time, nearest)
            return output
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
            # mark the peer as unresponsive; falls through to an implicit `return None` (see docstring)
            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
  171. async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
  172. """
  173. Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
  174. Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
  175. """
  176. if request.peer: # if requested, add peer to the routing table
  177. asyncio.create_task(self.rpc_ping(request.peer, context))
  178. response = dht_pb2.FindResponse(values=[], expiration_time=[], nearest=[], peer=self.node_info)
  179. for key_id in map(DHTID.from_bytes, request.keys):
  180. maybe_value, maybe_expiration_time = self.storage.get(key_id)
  181. cached_value, cached_expiration_time = self.cache.get(key_id)
  182. if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
  183. maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
  184. nearest_neighbors = self.routing_table.get_nearest_neighbors(
  185. key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
  186. if nearest_neighbors:
  187. peer_ids, endpoints = zip(*nearest_neighbors)
  188. else:
  189. peer_ids, endpoints = [], []
  190. response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
  191. response.expiration_time.append(maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
  192. response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
  193. return response
    async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
        """
        This method is called on every incoming AND outgoing request to update the routing table

        :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
        :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
        :param responded: for outgoing requests, this indicates whether recipient responded or not.
            For incoming requests, this should always be True
        """
        # if the caller doesn't know the node id, try to resolve it from the routing table by endpoint
        node_id = node_id if node_id is not None else self.routing_table.get(endpoint=peer_endpoint)
        if responded:  # incoming request or outgoing request with response
            if node_id not in self.routing_table:
                # we just met a new node, maybe we know some values that it *should* store
                data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
                for key, value, expiration_time in list(self.storage.items()):
                    neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                    if neighbors:
                        nearest_distance = neighbors[0][0].xor_distance(key)
                        farthest_distance = neighbors[-1][0].xor_distance(key)
                        # the newcomer is closer to this key than our current farthest replica holder ...
                        new_node_should_store = node_id.xor_distance(key) < farthest_distance
                        # ... and we ourselves are closer than any known neighbor, i.e. responsible for replication
                        this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                    # if we know no neighbors at all, send the key unconditionally
                    if not neighbors or (new_node_should_store and this_node_is_responsible):
                        data_to_send.append((key, value, expiration_time))
                if data_to_send:
                    # fire-and-forget replication of relevant keys to the newcomer
                    asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_endpoint)
            if maybe_node_to_ping is not None:
                # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
                asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint
        else:  # we sent outgoing request and peer did not respond
            if node_id is not None and node_id in self.routing_table:
                del self.routing_table[node_id]
# wire-level sentinels meaning "no value found" — protobuf repeated fields cannot carry None
_NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')
  227. class LocalStorage:
  228. """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration_time) """
  229. def __init__(self, maxsize: Optional[int] = None):
  230. self.cache_size = maxsize or float("inf")
  231. self.data = dict()
  232. self.expiration_heap = []
  233. self.key_to_heap = dict()
  234. def remove_outdated(self):
  235. while self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
  236. or len(self.expiration_heap) > self.cache_size):
  237. heap_entry = heapq.heappop(self.expiration_heap)
  238. key = heap_entry[1]
  239. if self.key_to_heap[key] == heap_entry:
  240. del self.data[key], self.key_to_heap[key]
  241. def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
  242. """
  243. Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
  244. :returns: True if new value was stored, False it was rejected (current value is newer)
  245. """
  246. if expiration_time < get_dht_time():
  247. return False
  248. self.key_to_heap[key] = (expiration_time, key)
  249. heapq.heappush(self.expiration_heap, (expiration_time, key))
  250. if key in self.data:
  251. if self.data[key][1] < expiration_time:
  252. self.data[key] = (value, expiration_time)
  253. return True
  254. return False
  255. self.data[key] = (value, expiration_time)
  256. self.remove_outdated()
  257. return True
  258. def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
  259. """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
  260. self.remove_outdated()
  261. if key in self.data:
  262. return self.data[key]
  263. return None, None
  264. def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
  265. """ Iterate over (key, value, expiration_time) tuples stored in this storage """
  266. self.remove_outdated()
  267. return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())