test_dht_schema.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import re
  2. import pydantic
  3. import pytest
  4. from pydantic import conint
  5. from typing import Dict
  6. from hivemind.dht import get_dht_time
  7. from hivemind.dht.node import DHTNode, LOCALHOST
  8. from hivemind.dht.schema import SchemaValidator, conbytes
  9. @pytest.fixture
  10. async def dht_nodes_with_schema():
  11. class Schema(pydantic.BaseModel):
  12. experiment_name: bytes
  13. n_batches: Dict[bytes, conint(ge=0, strict=True)]
  14. signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
  15. validator = SchemaValidator(Schema)
  16. alice = await DHTNode.create(record_validator=validator)
  17. bob = await DHTNode.create(
  18. record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
  19. return alice, bob
  20. @pytest.mark.forked
  21. @pytest.mark.asyncio
  22. async def test_keys_outside_schema(dht_nodes_with_schema):
  23. alice, bob = dht_nodes_with_schema
  24. assert await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
  25. for peer in [alice, bob]:
  26. assert (await peer.get(b'unknown_key', latest=True)).value == b'foo_bar'
  27. @pytest.mark.forked
  28. @pytest.mark.asyncio
  29. async def test_expecting_regular_value(dht_nodes_with_schema):
  30. alice, bob = dht_nodes_with_schema
  31. # Regular value (bytes) expected
  32. assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
  33. assert not await bob.store(b'experiment_name', 666, get_dht_time() + 10)
  34. assert not await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10,
  35. subkey=b'subkey')
  36. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  37. assert not await bob.store(b'experiment_name', [], get_dht_time() + 10)
  38. assert not await bob.store(b'experiment_name', [1, 2, 3], get_dht_time() + 10)
  39. for peer in [alice, bob]:
  40. assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
  41. @pytest.mark.forked
  42. @pytest.mark.asyncio
  43. async def test_expecting_dictionary(dht_nodes_with_schema):
  44. alice, bob = dht_nodes_with_schema
  45. # Dictionary (bytes -> non-negative int) expected
  46. assert await bob.store(b'n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
  47. assert await bob.store(b'n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
  48. assert not await bob.store(b'n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
  49. assert not await bob.store(b'n_batches', 666, get_dht_time() + 10)
  50. assert not await bob.store(b'n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
  51. assert not await bob.store(b'n_batches', 666, get_dht_time() + 10, subkey=666)
  52. # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
  53. assert not await bob.store(b'n_batches', {b'uid3': 779}, get_dht_time() + 10)
  54. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  55. assert not await bob.store(b'n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
  56. assert not await bob.store(b'n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
  57. assert not await bob.store(b'n_batches', [], get_dht_time() + 10)
  58. assert not await bob.store(b'n_batches', [(b'uid3', 779)], get_dht_time() + 10)
  59. # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
  60. assert not await bob.store(b'n_batches', '', get_dht_time() + 10)
  61. for peer in [alice, bob]:
  62. dictionary = (await peer.get(b'n_batches', latest=True)).value
  63. assert (len(dictionary) == 2 and
  64. dictionary[b'uid1'].value == 777 and
  65. dictionary[b'uid2'].value == 778)
  66. @pytest.mark.forked
  67. @pytest.mark.asyncio
  68. async def test_expecting_public_keys(dht_nodes_with_schema):
  69. alice, bob = dht_nodes_with_schema
  70. # Subkeys expected to contain a public key
  71. # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
  72. assert await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
  73. subkey=b'uid[owner:public-key]')
  74. assert not await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
  75. subkey=b'uid-without-public-key')
  76. for peer in [alice, bob]:
  77. dictionary = (await peer.get(b'signed_data', latest=True)).value
  78. assert (len(dictionary) == 1 and
  79. dictionary[b'uid[owner:public-key]'].value == b'foo_bar')