test_dht_crypto.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import dataclasses
  2. import pytest
  3. from hivemind.dht import get_dht_time
  4. from hivemind.dht.crypto import RSASignatureValidator
  5. from hivemind.dht.validation import DHTRecord
  def test_rsa_signature_validator():
      """Check that RSASignatureValidator accepts the owner's signatures and rejects forgeries.

      Three validators are involved: ``sender_validator`` signs records,
      ``receiver_validator`` verifies them, and ``mallory_validator`` is a third
      party whose signature must not be accepted for records carrying the sender's
      ownership marker.
      """
      receiver_validator = RSASignatureValidator()
      sender_validator = RSASignatureValidator()
      mallory_validator = RSASignatureValidator()

      plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                               expiration_time=get_dht_time() + 10)
      # Records whose key or subkey carries the sender's ownership marker are the
      # protected ones; both placements of the marker are exercised.
      protected_records = [
          dataclasses.replace(plain_record,
                              key=plain_record.key + sender_validator.ownership_marker),
          dataclasses.replace(plain_record,
                              subkey=plain_record.subkey + sender_validator.ownership_marker),
      ]

      # test 1: Non-protected record (no signature added)
      assert sender_validator.sign_value(plain_record) == plain_record.value
      assert receiver_validator.validate(plain_record)

      # test 2: Correct signatures
      signed_records = [dataclasses.replace(record, value=sender_validator.sign_value(record))
                        for record in protected_records]
      for record in signed_records:
          assert receiver_validator.validate(record)
          assert receiver_validator.strip_value(record) == b'value'

      # test 3: Invalid signatures
      # Copy instead of aliasing protected_records: `+=` on a list extends it in
      # place, so an alias would grow protected_records too and the mallory case
      # below would also (unintentionally) sign the already-tampered records.
      invalid_records = list(protected_records)  # Without signature
      invalid_records += [dataclasses.replace(record,
                                              value=record.value + b'[signature:INVALID_BYTES]')
                          for record in protected_records]  # With invalid signature
      invalid_records += [dataclasses.replace(record, value=mallory_validator.sign_value(record))
                          for record in protected_records]  # With someone else's signature
      for record in invalid_records:
          assert not receiver_validator.validate(record)