|
@@ -2,10 +2,12 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
+from itertools import zip_longest
|
|
|
from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
|
|
|
|
|
|
import grpc
|
|
|
|
|
|
+from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
|
|
|
from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
|
|
|
from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
|
|
|
from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
|
|
@@ -20,6 +22,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
|
|
|
node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
|
|
|
channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
|
|
|
storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
|
|
|
+ record_validator: Optional[RecordValidatorBase]
|
|
|
# fmt:on
|
|
|
|
|
|
serializer = MSGPackSerializer # used to pack/unpack DHT Values for transfer over network
|
|
@@ -30,6 +33,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
|
|
|
cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
|
|
|
parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
|
|
|
listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
|
|
|
+ record_validator: Optional[RecordValidatorBase] = None,
|
|
|
channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
|
|
|
"""
|
|
|
A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
|
|
@@ -49,6 +53,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
|
|
|
self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
|
|
|
self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
|
|
|
self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
|
|
|
+ self.record_validator = record_validator
|
|
|
|
|
|
if listen: # set up server to process incoming rpc requests
|
|
|
grpc.aio.init_grpc_aio()
|
|
@@ -221,17 +226,28 @@ class DHTProtocol(dht_grpc.DHTServicer):
|
|
|
asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
|
|
|
assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
|
|
|
response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
|
|
|
- keys = map(DHTID.from_bytes, request.keys)
|
|
|
- for key_id, tag, value_bytes, expiration_time, in_cache in zip(
|
|
|
- keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
|
|
|
+ for key, tag, value_bytes, expiration_time, in_cache in zip(
|
|
|
+ request.keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
|
|
|
+ key_id = DHTID.from_bytes(key)
|
|
|
storage = self.cache if in_cache else self.storage
|
|
|
- if tag == self.IS_REGULAR_VALUE: # store normal value without subkeys
|
|
|
- response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
|
|
|
- elif tag == self.IS_DICTIONARY: # store an entire dictionary with several subkeys
|
|
|
+
|
|
|
+ if tag == self.IS_DICTIONARY: # store an entire dictionary with several subkeys
|
|
|
value_dictionary = self.serializer.loads(value_bytes)
|
|
|
assert isinstance(value_dictionary, DictionaryDHTValue)
|
|
|
+ if not self._validate_dictionary(key, value_dictionary):
|
|
|
+ response.store_ok.append(False)
|
|
|
+ continue
|
|
|
+
|
|
|
response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
|
|
|
for subkey, item in value_dictionary.items()))
|
|
|
+ continue
|
|
|
+
|
|
|
+ if not self._validate_record(key, tag, value_bytes, expiration_time):
|
|
|
+ response.store_ok.append(False)
|
|
|
+ continue
|
|
|
+
|
|
|
+ if tag == self.IS_REGULAR_VALUE: # store normal value without subkeys
|
|
|
+ response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
|
|
|
else: # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
|
|
|
subkey = self.serializer.loads(tag)
|
|
|
response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
|
|
@@ -261,15 +277,25 @@ class DHTProtocol(dht_grpc.DHTServicer):
|
|
|
|
|
|
output = {} # unpack data depending on its type
|
|
|
for key, result in zip(keys, response.results):
|
|
|
+ key_bytes = DHTID.to_bytes(key)
|
|
|
nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
|
|
|
|
|
|
if result.type == dht_pb2.NOT_FOUND:
|
|
|
output[key] = None, nearest
|
|
|
elif result.type == dht_pb2.FOUND_REGULAR:
|
|
|
+ if not self._validate_record(
|
|
|
+ key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time):
|
|
|
+ output[key] = None, nearest
|
|
|
+ continue
|
|
|
+
|
|
|
output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
|
|
|
elif result.type == dht_pb2.FOUND_DICTIONARY:
|
|
|
- deserialized_dictionary = self.serializer.loads(result.value)
|
|
|
- output[key] = ValueWithExpiration(deserialized_dictionary, result.expiration_time), nearest
|
|
|
+ value_dictionary = self.serializer.loads(result.value)
|
|
|
+ if not self._validate_dictionary(key_bytes, value_dictionary):
|
|
|
+ output[key] = None, nearest
|
|
|
+ continue
|
|
|
+
|
|
|
+ output[key] = ValueWithExpiration(value_dictionary, result.expiration_time), nearest
|
|
|
else:
|
|
|
logger.error(f"Unknown result type: {result.type}")
|
|
|
|
|
@@ -346,6 +372,26 @@ class DHTProtocol(dht_grpc.DHTServicer):
|
|
|
if node_id is not None and node_id in self.routing_table:
|
|
|
del self.routing_table[node_id]
|
|
|
|
|
|
+ def _validate_record(self, key_bytes: bytes, subkey_bytes: bytes, value_bytes: bytes,
|
|
|
+ expiration_time: float) -> bool:
|
|
|
+ if self.record_validator is None:
|
|
|
+ return True
|
|
|
+
|
|
|
+ record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
|
|
|
+ return self.record_validator.validate(record)
|
|
|
+
|
|
|
+ def _validate_dictionary(self, key_bytes: bytes, dictionary: DictionaryDHTValue) -> bool:
|
|
|
+ if self.record_validator is None:
|
|
|
+ return True
|
|
|
+
|
|
|
+ with dictionary.freeze():
|
|
|
+ for subkey, (value_bytes, expiration_time) in dictionary.items():
|
|
|
+ subkey_bytes = self.serializer.dumps(subkey)
|
|
|
+ record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
|
|
|
+ if not self.record_validator.validate(record):
|
|
|
+ return False
|
|
|
+ return True
|
|
|
+
|
|
|
|
|
|
class ValidationError(Exception):
|
|
|
""" This exception is thrown if DHT node didn't pass validation by other nodes. """
|