test_dht_crypto.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import dataclasses
  2. import pickle
  3. import multiprocessing as mp
  4. import pytest
  5. import hivemind
  6. from hivemind.utils.timed_storage import get_dht_time
  7. from hivemind.dht.crypto import RSASignatureValidator
  8. from hivemind.dht.node import LOCALHOST, DHTNode
  9. from hivemind.dht.validation import DHTRecord
  10. from hivemind.utils.crypto import RSAPrivateKey
  11. def test_rsa_signature_validator():
  12. receiver_validator = RSASignatureValidator()
  13. sender_validator = RSASignatureValidator(RSAPrivateKey())
  14. mallory_validator = RSASignatureValidator(RSAPrivateKey())
  15. plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
  16. expiration_time=get_dht_time() + 10)
  17. protected_records = [
  18. dataclasses.replace(plain_record,
  19. key=plain_record.key + sender_validator.local_public_key),
  20. dataclasses.replace(plain_record,
  21. subkey=plain_record.subkey + sender_validator.local_public_key),
  22. ]
  23. # test 1: Non-protected record (no signature added)
  24. assert sender_validator.sign_value(plain_record) == plain_record.value
  25. assert receiver_validator.validate(plain_record)
  26. # test 2: Correct signatures
  27. signed_records = [dataclasses.replace(record, value=sender_validator.sign_value(record))
  28. for record in protected_records]
  29. for record in signed_records:
  30. assert receiver_validator.validate(record)
  31. assert receiver_validator.strip_value(record) == b'value'
  32. # test 3: Invalid signatures
  33. signed_records = protected_records # Without signature
  34. signed_records += [dataclasses.replace(record,
  35. value=record.value + b'[signature:INVALID_BYTES]')
  36. for record in protected_records] # With invalid signature
  37. signed_records += [dataclasses.replace(record, value=mallory_validator.sign_value(record))
  38. for record in protected_records] # With someone else's signature
  39. for record in signed_records:
  40. assert not receiver_validator.validate(record)
  41. def test_cached_key():
  42. first_validator = RSASignatureValidator()
  43. second_validator = RSASignatureValidator()
  44. assert first_validator.local_public_key == second_validator.local_public_key
  45. third_validator = RSASignatureValidator(RSAPrivateKey())
  46. assert first_validator.local_public_key != third_validator.local_public_key
  47. def test_validator_instance_is_picklable():
  48. # Needs to be picklable because the validator instance may be sent between processes
  49. original_validator = RSASignatureValidator()
  50. unpickled_validator = pickle.loads(pickle.dumps(original_validator))
  51. # To check that the private key was pickled and unpickled correctly, we sign a record
  52. # with the original public key using the unpickled validator and then validate the signature
  53. record = DHTRecord(key=b'key', subkey=b'subkey' + original_validator.local_public_key,
  54. value=b'value', expiration_time=get_dht_time() + 10)
  55. signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
  56. assert b'[signature:' in signed_record.value
  57. assert original_validator.validate(signed_record)
  58. assert unpickled_validator.validate(signed_record)
  59. def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
  60. validator = conn.recv()
  61. record = conn.recv()
  62. record = dataclasses.replace(record, value=validator.sign_value(record))
  63. conn.send(record)
  64. return record
  65. def test_signing_in_different_process():
  66. parent_conn, child_conn = mp.Pipe()
  67. process = mp.Process(target=get_signed_record, args=[child_conn])
  68. process.start()
  69. validator = RSASignatureValidator()
  70. parent_conn.send(validator)
  71. record = DHTRecord(key=b'key', subkey=b'subkey' + validator.local_public_key,
  72. value=b'value', expiration_time=get_dht_time() + 10)
  73. parent_conn.send(record)
  74. signed_record = parent_conn.recv()
  75. assert b'[signature:' in signed_record.value
  76. assert validator.validate(signed_record)
  77. @pytest.mark.forked
  78. @pytest.mark.asyncio
  79. async def test_dhtnode_signatures():
  80. alice = await DHTNode.create(record_validator=RSASignatureValidator())
  81. bob = await DHTNode.create(
  82. record_validator=RSASignatureValidator(RSAPrivateKey()),
  83. initial_peers=[f"{LOCALHOST}:{alice.port}"])
  84. mallory = await DHTNode.create(
  85. record_validator=RSASignatureValidator(RSAPrivateKey()),
  86. initial_peers=[f"{LOCALHOST}:{alice.port}"])
  87. key = b'key'
  88. subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key
  89. assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
  90. assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
  91. store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
  92. assert not store_ok
  93. assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
  94. assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
  95. assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
  96. await bob.shutdown() # Bob has shut down, now Mallory is the single peer of Alice
  97. store_ok = await mallory.store(key, b'updated_fake_value',
  98. hivemind.get_dht_time() + 10, subkey=subkey)
  99. assert not store_ok
  100. assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'