- """ RPC protocol that provides nodes a way to communicate with each other. Based on gRPC.AIO. """
- from __future__ import annotations
- import asyncio
- from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
- import grpc
- from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
- from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
- from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
- from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
- from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
- logger = get_logger(__name__)
class DHTProtocol(dht_grpc.DHTServicer):
    """
    Kademlia-style node-to-node communication layer: implements both the server side (``rpc_*`` methods,
    exposed to peers via gRPC.AIO) and the client side (``call_*`` coroutines, used locally by DHTNode).
    Every incoming and outgoing request also feeds :meth:`update_routing_table` to keep the routing table fresh.
    """
    # fmt:off
    node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
    channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
    storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
    # fmt:on

    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
    # reserved subkey tags used on the wire: IS_REGULAR_VALUE marks a plain value stored under a key,
    # IS_DICTIONARY (empty bytes) marks a value that is a serialized DictionaryDHTValue with its own subkeys
    RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''

    @classmethod
    async def create(
            cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
            listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
            channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
        """
        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
        As a side-effect, DHTProtocol also maintains a routing table as described in
        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

        See DHTNode (node.py) for a more detailed description.

        :param node_id: this node's identifier in the DHT keyspace
        :param bucket_size: max routing-table bucket size (Kademlia's k)
        :param depth_modulo: routing-table bucket-splitting parameter, forwarded to RoutingTable
        :param num_replicas: number of nearest peers that should hold a replica of each key
        :param wait_timeout: per-RPC timeout in seconds for all outgoing call_* requests
        :param parallel_rpc: max concurrent outgoing RPCs; None means unlimited
        :param cache_size: max entries in the local cache storage (None = DHTLocalStorage default)
        :param listen: if True, start a gRPC server for incoming requests; if False, run client-only
        :param listen_on: address:port to bind the server to; a trailing '*' means "pick a free port"
        :param endpoint: the externally visible endpoint to advertise to peers (optional)
        :param channel_options: extra grpc channel options applied to every outgoing channel
        :param kwargs: forwarded to grpc.aio.server (only meaningful when listen=True)

        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
        """
        self = cls(_initialized_with_create=True)
        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
        self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
        self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
        # semaphore caps concurrent outgoing RPCs; float('inf') effectively disables the cap
        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

        if listen:  # set up server to process incoming rpc requests
            grpc.aio.init_grpc_aio()
            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
            dht_grpc.add_DHTServicer_to_server(self, self.server)

            self.port = self.server.add_insecure_port(listen_on)
            assert self.port != 0, f"Failed to listen to {listen_on}"
            if endpoint is not None and endpoint.endswith('*'):
                endpoint = replace_port(endpoint, self.port)
            # advertise our endpoint (or the proto default if none given) so peers can call us back
            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
                                              endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
            await self.server.start()
        else:  # not listening to incoming requests, client-only mode
            # note: use empty node_info so peers won't add you to their routing tables
            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
                # NOTE(review): these two f-string fragments concatenate without a space
                # ("listen_onand kwargs..."); message text left unchanged here.
                logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on"
                               f"and kwargs have no effect (unused kwargs: {kwargs})")
        return self

    def __init__(self, *, _initialized_with_create: bool = False):
        """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
        super().__init__()

    async def shutdown(self, timeout: Optional[float] = None):
        """ Process existing requests, close all connections and stop the server """
        if self.server:
            # grace period: pending RPCs get `timeout` seconds to finish before being cancelled
            await self.server.stop(timeout)
        else:
            logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")

    def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
        """ get a DHTStub that sends requests to a given peer """
        # channels are cached per-peer to avoid re-establishing connections on every call
        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)

    async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
        """
        Get peer's node id and add him to the routing table. If peer doesn't respond, return None
        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
        :param validate: if True, validates that node's endpoint is available
        :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table

        :return: node's DHTID, if peer responded and decided to send his node_id
        """
        try:
            async with self.rpc_semaphore:
                ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
                time_requested = get_dht_time()
                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                time_responded = get_dht_time()
        except grpc.aio.AioRpcError as error:
            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
            response = None
        responded = bool(response and response.peer and response.peer.node_id)

        if responded and validate:
            try:
                # a listening node expects the peer to have pinged it back successfully (response.available)
                if self.server is not None and not response.available:
                    raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
                                          f"Make sure that this port is open for incoming requests.")

                # clock check: only performed if the peer actually reported its dht_time (non-default value)
                if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                    if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
                            response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
                        raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
                                              f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
            except ValidationError as e:
                if strict:
                    raise
                else:
                    logger.warning(repr(e))

        peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
        # fire-and-forget: routing-table maintenance must not block the caller
        asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
        return peer_id

    async def get_outgoing_request_endpoint(self, peer: Endpoint) -> Optional[Endpoint]:
        """ ask this peer how it perceives this node's outgoing request address """
        try:
            async with self.rpc_semaphore:
                # peer=None keeps our info out of the request so the remote node won't add us to its table
                ping_request = dht_pb2.PingRequest(peer=None, validate=False)
                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
                    return response.sender_endpoint
        except grpc.aio.AioRpcError as error:
            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")

    async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext) -> dht_pb2.PingResponse:
        """ Some node wants us to add it to our routing table. """
        response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
                                        dht_time=get_dht_time(), available=False)

        # only consider the sender a DHT peer if it sent full node info (id + port); client-only nodes send neither
        if request.peer and request.peer.node_id and request.peer.rpc_port:
            sender_id = DHTID.from_bytes(request.peer.node_id)
            if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
            else:
                # otherwise derive the endpoint from the transport address, substituting the advertised rpc port
                sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)

            response.sender_endpoint = sender_endpoint
            if request.validate:
                # ping the sender back (without further validation) to verify its endpoint is reachable
                response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id

            asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
                                                          responded=response.available or not request.validate))

        return response

    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
                         values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
                         expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                         subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Optional[List[bool]]:
        """
        Ask a recipient to store several (key, value : expiration_time) items or update their older value

        :param peer: request this peer to store the data
        :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
        :param values: a list of N serialized values (bytes) for each respective key
        :param expiration_time: a list of N expiration timestamps for each respective key-value pair(see get_dht_time())
        :param subkeys: a list of N optional sub-keys. If None, stores value normally. If not subkey is not None:
          1) if local storage doesn't have :key:, create a new dictionary {subkey: (value, expiration_time)}
          2) if local storage already has a dictionary under :key:, try add (subkey, value, exp_time) to that dictionary
          2) if local storage associates :key: with a normal value with smaller expiration, clear :key: and perform (1)
          3) finally, if local storage currently associates :key: with a normal value with larger expiration, do nothing
        :param in_cache: a list of booleans, True = store i-th key in cache, value = store i-th key locally
        :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
         until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size

        :return: list of [True / False] True = stored, False = failed (found newer value or no response)
         if peer did not respond (e.g. due to timeout or congestion), returns None
        """
        # broadcast scalar arguments to per-key lists so all five sequences are aligned
        if isinstance(expiration_time, DHTExpiration):
            expiration_time = [expiration_time] * len(keys)
        if subkeys is None:
            subkeys = [None] * len(keys)
        in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
        in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool

        keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
        for i in range(len(keys)):
            if subkeys[i] is None:  # add default sub-key if not specified
                subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
            else:
                subkeys[i] = self.serializer.dumps(subkeys[i])
            if isinstance(values[i], DictionaryDHTValue):
                assert subkeys[i] == self.IS_DICTIONARY, "Please don't specify subkey when storing an entire dictionary"
                values[i] = self.serializer.dumps(values[i])

        assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), subkeys=subkeys, values=values,
                                             expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get_dht_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
            return response.store_ok
        except grpc.aio.AioRpcError as error:
            logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
            # mark the peer as unresponsive (looked up by endpoint since we may not know its id)
            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
            return None

    async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
        """ Some node wants us to store this (key, value) pair """
        if request.peer:  # if requested, add peer to the routing table
            asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))

        assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
        response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
        keys = map(DHTID.from_bytes, request.keys)
        for key_id, tag, value_bytes, expiration_time, in_cache in zip(
                keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
            storage = self.cache if in_cache else self.storage
            if tag == self.IS_REGULAR_VALUE:  # store normal value without subkeys
                response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
            elif tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
                value_dictionary = self.serializer.loads(value_bytes)
                assert isinstance(value_dictionary, DictionaryDHTValue)
                # succeed only if every (subkey, value) pair of the incoming dictionary was stored
                response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
                                             for subkey, item in value_dictionary.items()))
            else:  # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
                subkey = self.serializer.loads(tag)
                response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
        return response

    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[Dict[
            DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
        """
        Request keys from a peer. For each key, look for its (value, expiration time) locally and
         k additional peers that are most likely to have this key (ranked by XOR distance)

        :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
         value: value stored by the recipient with that key, or None if peer doesn't have this value
         expiration time: expiration time of the returned value, None if no value was found
         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
         If peer didn't respond, returns None
        """
        keys = list(keys)
        find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
            assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"

            output = {}  # unpack data depending on its type
            for key, result in zip(keys, response.results):
                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))

                if result.type == dht_pb2.NOT_FOUND:
                    output[key] = None, nearest
                elif result.type == dht_pb2.FOUND_REGULAR:
                    output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
                elif result.type == dht_pb2.FOUND_DICTIONARY:
                    # dictionary values travel serialized; unpack to DictionaryDHTValue before returning
                    deserialized_dictionary = self.serializer.loads(result.value)
                    output[key] = ValueWithExpiration(deserialized_dictionary, result.expiration_time), nearest
                else:
                    logger.error(f"Unknown result type: {result.type}")

            return output
        except grpc.aio.AioRpcError as error:
            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))

    async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
        """
        Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
        Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
        """
        if request.peer:  # if requested, add peer to the routing table
            asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))

        response = dht_pb2.FindResponse(results=[], peer=self.node_info)
        for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
            maybe_item = self.storage.get(key_id)
            cached_item = self.cache.get(key_id)
            # prefer the cached copy if it is fresher than (or replaces a missing) stored value
            if cached_item is not None and (maybe_item is None
                                            or cached_item.expiration_time > maybe_item.expiration_time):
                maybe_item = cached_item

            if maybe_item is None:  # value not found
                item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
            elif isinstance(maybe_item.value, DictionaryDHTValue):
                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_item.value),
                                          expiration_time=maybe_item.expiration_time)
            else:  # found regular value
                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
                                          expiration_time=maybe_item.expiration_time)

            # NOTE(review): this assumes request.peer carries a valid node_id; a request with an empty
            # peer field would pass empty bytes to DHTID.from_bytes — confirm upstream callers always set it
            for node_id, endpoint in self.routing_table.get_nearest_neighbors(
                    key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
                item.nearest_node_ids.append(node_id.to_bytes())
                item.nearest_endpoints.append(endpoint)
            response.results.append(item)
        return response

    async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded: bool = True):
        """
        This method is called on every incoming AND outgoing request to update the routing table

        :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
        :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
        :param responded: for outgoing requests, this indicated whether recipient responded or not.
          For incoming requests, this should always be True
        """
        node_id = node_id if node_id is not None else self.routing_table.get(endpoint=peer_endpoint)
        if responded:  # incoming request or outgoing request with response
            if node_id not in self.routing_table:
                # we just met a new node, maybe we know some values that it *should* store
                data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
                for key, item in list(self.storage.items()):
                    neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                    if neighbors:
                        nearest_distance = neighbors[0][0].xor_distance(key)
                        farthest_distance = neighbors[-1][0].xor_distance(key)
                        new_node_should_store = node_id.xor_distance(key) < farthest_distance
                        this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                    # replicate when we know no other holders, or when the newcomer belongs among the
                    # num_replicas closest nodes and we are the node responsible for pushing replicas
                    if not neighbors or (new_node_should_store and this_node_is_responsible):
                        data_to_send.append((key, item.value, item.expiration_time))
                if data_to_send:
                    asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))

            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_endpoint)
            if maybe_node_to_ping is not None:
                # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
                asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint

        else:  # we sent outgoing request and peer did not respond
            if node_id is not None and node_id in self.routing_table:
                del self.routing_table[node_id]
class ValidationError(Exception):
    """ Raised when a DHT node fails the availability/clock validation performed by its peers. """
|