test_dht_crypto.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import dataclasses
  2. import multiprocessing as mp
  3. import pickle
  4. import pytest
  5. import hivemind
  6. from hivemind.dht.crypto import RSASignatureValidator
  7. from hivemind.dht.node import DHTNode
  8. from hivemind.dht.validation import DHTRecord
  9. from hivemind.utils.crypto import RSAPrivateKey
  10. from hivemind.utils.timed_storage import get_dht_time
  11. def test_rsa_signature_validator():
  12. receiver_validator = RSASignatureValidator()
  13. sender_validator = RSASignatureValidator(RSAPrivateKey())
  14. mallory_validator = RSASignatureValidator(RSAPrivateKey())
  15. plain_record = DHTRecord(key=b"key", subkey=b"subkey", value=b"value", expiration_time=get_dht_time() + 10)
  16. protected_records = [
  17. dataclasses.replace(plain_record, key=plain_record.key + sender_validator.local_public_key),
  18. dataclasses.replace(plain_record, subkey=plain_record.subkey + sender_validator.local_public_key),
  19. ]
  20. # test 1: Non-protected record (no signature added)
  21. assert sender_validator.sign_value(plain_record) == plain_record.value
  22. assert receiver_validator.validate(plain_record)
  23. # test 2: Correct signatures
  24. signed_records = [
  25. dataclasses.replace(record, value=sender_validator.sign_value(record)) for record in protected_records
  26. ]
  27. for record in signed_records:
  28. assert receiver_validator.validate(record)
  29. assert receiver_validator.strip_value(record) == b"value"
  30. # test 3: Invalid signatures
  31. signed_records = protected_records # Without signature
  32. signed_records += [
  33. dataclasses.replace(record, value=record.value + b"[signature:INVALID_BYTES]") for record in protected_records
  34. ] # With invalid signature
  35. signed_records += [
  36. dataclasses.replace(record, value=mallory_validator.sign_value(record)) for record in protected_records
  37. ] # With someone else's signature
  38. for record in signed_records:
  39. assert not receiver_validator.validate(record)
  40. def test_cached_key():
  41. first_validator = RSASignatureValidator()
  42. second_validator = RSASignatureValidator()
  43. assert first_validator.local_public_key == second_validator.local_public_key
  44. third_validator = RSASignatureValidator(RSAPrivateKey())
  45. assert first_validator.local_public_key != third_validator.local_public_key
  46. def test_validator_instance_is_picklable():
  47. # Needs to be picklable because the validator instance may be sent between processes
  48. original_validator = RSASignatureValidator()
  49. unpickled_validator = pickle.loads(pickle.dumps(original_validator))
  50. # To check that the private key was pickled and unpickled correctly, we sign a record
  51. # with the original public key using the unpickled validator and then validate the signature
  52. record = DHTRecord(
  53. key=b"key",
  54. subkey=b"subkey" + original_validator.local_public_key,
  55. value=b"value",
  56. expiration_time=get_dht_time() + 10,
  57. )
  58. signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
  59. assert b"[signature:" in signed_record.value
  60. assert original_validator.validate(signed_record)
  61. assert unpickled_validator.validate(signed_record)
  62. def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
  63. validator = conn.recv()
  64. record = conn.recv()
  65. record = dataclasses.replace(record, value=validator.sign_value(record))
  66. conn.send(record)
  67. return record
  68. def test_signing_in_different_process():
  69. parent_conn, child_conn = mp.Pipe()
  70. process = mp.Process(target=get_signed_record, args=[child_conn])
  71. process.start()
  72. validator = RSASignatureValidator()
  73. parent_conn.send(validator)
  74. record = DHTRecord(
  75. key=b"key", subkey=b"subkey" + validator.local_public_key, value=b"value", expiration_time=get_dht_time() + 10
  76. )
  77. parent_conn.send(record)
  78. signed_record = parent_conn.recv()
  79. assert b"[signature:" in signed_record.value
  80. assert validator.validate(signed_record)
  81. @pytest.mark.forked
  82. @pytest.mark.asyncio
  83. async def test_dhtnode_signatures():
  84. alice = await DHTNode.create(record_validator=RSASignatureValidator())
  85. initial_peers = await alice.get_visible_maddrs()
  86. bob = await DHTNode.create(record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
  87. mallory = await DHTNode.create(
  88. record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers
  89. )
  90. key = b"key"
  91. subkey = b"protected_subkey" + bob.protocol.record_validator.local_public_key
  92. assert await bob.store(key, b"true_value", hivemind.get_dht_time() + 10, subkey=subkey)
  93. assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
  94. store_ok = await mallory.store(key, b"fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
  95. assert not store_ok
  96. assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
  97. assert await bob.store(key, b"updated_true_value", hivemind.get_dht_time() + 10, subkey=subkey)
  98. assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"
  99. await bob.shutdown() # Bob has shut down, now Mallory is the single peer of Alice
  100. store_ok = await mallory.store(key, b"updated_fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
  101. assert not store_ok
  102. assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"