# test_dht_crypto.py (5.3 KB)
  1. import dataclasses
  2. import pickle
  3. import multiprocessing as mp
  4. import pytest
  5. import hivemind
  6. from hivemind.utils.timed_storage import get_dht_time
  7. from hivemind.dht.crypto import RSASignatureValidator
  8. from hivemind.dht.node import DHTNode
  9. from hivemind.dht.validation import DHTRecord
  10. from hivemind.utils.crypto import RSAPrivateKey
  11. def test_rsa_signature_validator():
  12. receiver_validator = RSASignatureValidator()
  13. sender_validator = RSASignatureValidator(RSAPrivateKey())
  14. mallory_validator = RSASignatureValidator(RSAPrivateKey())
  15. plain_record = DHTRecord(key=b"key", subkey=b"subkey", value=b"value", expiration_time=get_dht_time() + 10)
  16. protected_records = [
  17. dataclasses.replace(plain_record, key=plain_record.key + sender_validator.local_public_key),
  18. dataclasses.replace(plain_record, subkey=plain_record.subkey + sender_validator.local_public_key),
  19. ]
  20. # test 1: Non-protected record (no signature added)
  21. assert sender_validator.sign_value(plain_record) == plain_record.value
  22. assert receiver_validator.validate(plain_record)
  23. # test 2: Correct signatures
  24. signed_records = [
  25. dataclasses.replace(record, value=sender_validator.sign_value(record)) for record in protected_records
  26. ]
  27. for record in signed_records:
  28. assert receiver_validator.validate(record)
  29. assert receiver_validator.strip_value(record) == b"value"
  30. # test 3: Invalid signatures
  31. signed_records = protected_records # Without signature
  32. signed_records += [
  33. dataclasses.replace(record, value=record.value + b"[signature:INVALID_BYTES]") for record in protected_records
  34. ] # With invalid signature
  35. signed_records += [
  36. dataclasses.replace(record, value=mallory_validator.sign_value(record)) for record in protected_records
  37. ] # With someone else's signature
  38. for record in signed_records:
  39. assert not receiver_validator.validate(record)
  40. def test_cached_key():
  41. first_validator = RSASignatureValidator()
  42. second_validator = RSASignatureValidator()
  43. assert first_validator.local_public_key == second_validator.local_public_key
  44. third_validator = RSASignatureValidator(RSAPrivateKey())
  45. assert first_validator.local_public_key != third_validator.local_public_key
  46. def test_validator_instance_is_picklable():
  47. # Needs to be picklable because the validator instance may be sent between processes
  48. original_validator = RSASignatureValidator()
  49. unpickled_validator = pickle.loads(pickle.dumps(original_validator))
  50. # To check that the private key was pickled and unpickled correctly, we sign a record
  51. # with the original public key using the unpickled validator and then validate the signature
  52. record = DHTRecord(
  53. key=b"key",
  54. subkey=b"subkey" + original_validator.local_public_key,
  55. value=b"value",
  56. expiration_time=get_dht_time() + 10,
  57. )
  58. signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
  59. assert b"[signature:" in signed_record.value
  60. assert original_validator.validate(signed_record)
  61. assert unpickled_validator.validate(signed_record)
  62. def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
  63. validator = conn.recv()
  64. record = conn.recv()
  65. record = dataclasses.replace(record, value=validator.sign_value(record))
  66. conn.send(record)
  67. return record
  68. def test_signing_in_different_process():
  69. parent_conn, child_conn = mp.Pipe()
  70. process = mp.Process(target=get_signed_record, args=[child_conn])
  71. process.start()
  72. validator = RSASignatureValidator()
  73. parent_conn.send(validator)
  74. record = DHTRecord(
  75. key=b"key", subkey=b"subkey" + validator.local_public_key, value=b"value", expiration_time=get_dht_time() + 10
  76. )
  77. parent_conn.send(record)
  78. signed_record = parent_conn.recv()
  79. assert b"[signature:" in signed_record.value
  80. assert validator.validate(signed_record)
  81. @pytest.mark.forked
  82. @pytest.mark.asyncio
  83. async def test_dhtnode_signatures():
  84. alice = await DHTNode.create(record_validator=RSASignatureValidator())
  85. initial_peers = await alice.get_visible_maddrs()
  86. bob = await DHTNode.create(record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
  87. mallory = await DHTNode.create(
  88. record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers
  89. )
  90. key = b"key"
  91. subkey = b"protected_subkey" + bob.protocol.record_validator.local_public_key
  92. assert await bob.store(key, b"true_value", hivemind.get_dht_time() + 10, subkey=subkey)
  93. assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
  94. store_ok = await mallory.store(key, b"fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
  95. assert not store_ok
  96. assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
  97. assert await bob.store(key, b"updated_true_value", hivemind.get_dht_time() + 10, subkey=subkey)
  98. assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"
  99. await bob.shutdown() # Bob has shut down, now Mallory is the single peer of Alice
  100. store_ok = await mallory.store(key, b"updated_fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
  101. assert not store_ok
  102. assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"