test_dht_schema.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import asyncio
  2. from typing import Dict
  3. import pytest
  4. from pydantic import BaseModel, StrictInt, conint
  5. import hivemind
  6. from hivemind.dht.node import DHTNode
  7. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  8. from hivemind.dht.validation import DHTRecord, RecordValidatorBase
  9. from hivemind.utils.timed_storage import get_dht_time
  10. class SampleSchema(BaseModel):
  11. experiment_name: bytes
  12. n_batches: Dict[bytes, conint(ge=0, strict=True)]
  13. signed_data: Dict[BytesWithPublicKey, bytes]
  14. @pytest.fixture
  15. async def dht_nodes_with_schema():
  16. validator = SchemaValidator(SampleSchema)
  17. alice = await DHTNode.create(record_validator=validator)
  18. bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
  19. yield alice, bob
  20. await asyncio.gather(alice.shutdown(), bob.shutdown())
  21. @pytest.mark.forked
  22. @pytest.mark.asyncio
  23. async def test_expecting_regular_value(dht_nodes_with_schema):
  24. alice, bob = dht_nodes_with_schema
  25. # Regular value (bytes) expected
  26. assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
  27. assert not await bob.store("experiment_name", 666, get_dht_time() + 10)
  28. assert not await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10, subkey=b"subkey")
  29. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  30. assert not await bob.store("experiment_name", [], get_dht_time() + 10)
  31. assert not await bob.store("experiment_name", [1, 2, 3], get_dht_time() + 10)
  32. for peer in [alice, bob]:
  33. assert (await peer.get("experiment_name", latest=True)).value == b"foo_bar"
  34. @pytest.mark.forked
  35. @pytest.mark.asyncio
  36. async def test_expecting_dictionary(dht_nodes_with_schema):
  37. alice, bob = dht_nodes_with_schema
  38. # Dictionary (bytes -> non-negative int) expected
  39. assert await bob.store("n_batches", 777, get_dht_time() + 10, subkey=b"uid1")
  40. assert await bob.store("n_batches", 778, get_dht_time() + 10, subkey=b"uid2")
  41. assert not await bob.store("n_batches", -666, get_dht_time() + 10, subkey=b"uid3")
  42. assert not await bob.store("n_batches", 666, get_dht_time() + 10)
  43. assert not await bob.store("n_batches", b"not_integer", get_dht_time() + 10, subkey=b"uid1")
  44. assert not await bob.store("n_batches", 666, get_dht_time() + 10, subkey=666)
  45. # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
  46. assert not await bob.store("n_batches", {b"uid3": 779}, get_dht_time() + 10)
  47. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  48. assert not await bob.store("n_batches", 779.5, get_dht_time() + 10, subkey=b"uid3")
  49. assert not await bob.store("n_batches", 779.0, get_dht_time() + 10, subkey=b"uid3")
  50. assert not await bob.store("n_batches", [], get_dht_time() + 10)
  51. assert not await bob.store("n_batches", [(b"uid3", 779)], get_dht_time() + 10)
  52. # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
  53. assert not await bob.store("n_batches", "", get_dht_time() + 10)
  54. for peer in [alice, bob]:
  55. dictionary = (await peer.get("n_batches", latest=True)).value
  56. assert len(dictionary) == 2 and dictionary[b"uid1"].value == 777 and dictionary[b"uid2"].value == 778
  57. @pytest.mark.forked
  58. @pytest.mark.asyncio
  59. async def test_expecting_public_keys(dht_nodes_with_schema):
  60. alice, bob = dht_nodes_with_schema
  61. # Subkeys expected to contain a public key
  62. # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
  63. assert await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid[owner:public-key]")
  64. assert not await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid-without-public-key")
  65. for peer in [alice, bob]:
  66. dictionary = (await peer.get("signed_data", latest=True)).value
  67. assert len(dictionary) == 1 and dictionary[b"uid[owner:public-key]"].value == b"foo_bar"
  68. @pytest.mark.forked
  69. @pytest.mark.asyncio
  70. async def test_keys_outside_schema(dht_nodes_with_schema):
  71. class Schema(BaseModel):
  72. some_field: StrictInt
  73. class MergedSchema(BaseModel):
  74. another_field: StrictInt
  75. for allow_extra_keys in [False, True]:
  76. validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
  77. assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
  78. alice = await DHTNode.create(record_validator=validator)
  79. bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
  80. store_ok = await bob.store("unknown_key", b"foo_bar", get_dht_time() + 10)
  81. assert store_ok == allow_extra_keys
  82. for peer in [alice, bob]:
  83. result = await peer.get("unknown_key", latest=True)
  84. if allow_extra_keys:
  85. assert result.value == b"foo_bar"
  86. else:
  87. assert result is None
  88. @pytest.mark.forked
  89. @pytest.mark.asyncio
  90. async def test_prefix():
  91. class Schema(BaseModel):
  92. field: StrictInt
  93. validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")
  94. alice = await DHTNode.create(record_validator=validator)
  95. bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
  96. assert await bob.store("prefix_field", 777, get_dht_time() + 10)
  97. assert not await bob.store("prefix_field", "string_value", get_dht_time() + 10)
  98. assert not await bob.store("field", 777, get_dht_time() + 10)
  99. for peer in [alice, bob]:
  100. assert (await peer.get("prefix_field", latest=True)).value == 777
  101. assert (await peer.get("field", latest=True)) is None
  102. await asyncio.gather(alice.shutdown(), bob.shutdown())
  103. @pytest.mark.forked
  104. @pytest.mark.asyncio
  105. async def test_merging_schema_validators(dht_nodes_with_schema):
  106. alice, bob = dht_nodes_with_schema
  107. class TrivialValidator(RecordValidatorBase):
  108. def validate(self, record: DHTRecord) -> bool:
  109. return True
  110. second_validator = TrivialValidator()
  111. # Can't merge with the validator of the different type
  112. assert not alice.protocol.record_validator.merge_with(second_validator)
  113. class SecondSchema(BaseModel):
  114. some_field: StrictInt
  115. another_field: str
  116. class ThirdSchema(BaseModel):
  117. another_field: StrictInt # Allow it to be a StrictInt as well
  118. for schema in [SecondSchema, ThirdSchema]:
  119. new_validator = SchemaValidator(schema, allow_extra_keys=False)
  120. for peer in [alice, bob]:
  121. assert peer.protocol.record_validator.merge_with(new_validator)
  122. assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
  123. assert await bob.store("some_field", 777, get_dht_time() + 10)
  124. assert not await bob.store("some_field", "string_value", get_dht_time() + 10)
  125. assert await bob.store("another_field", 42, get_dht_time() + 10)
  126. assert await bob.store("another_field", "string_value", get_dht_time() + 10)
  127. # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
  128. assert await bob.store("unknown_key", 999, get_dht_time() + 10)
  129. for peer in [alice, bob]:
  130. assert (await peer.get("experiment_name", latest=True)).value == b"foo_bar"
  131. assert (await peer.get("some_field", latest=True)).value == 777
  132. assert (await peer.get("another_field", latest=True)).value == "string_value"
  133. assert (await peer.get("unknown_key", latest=True)).value == 999
  134. @pytest.mark.forked
  135. def test_sending_validator_instance_between_processes():
  136. alice = hivemind.DHT(start=True)
  137. bob = hivemind.DHT(start=True, initial_peers=alice.get_visible_maddrs())
  138. alice.add_validators([SchemaValidator(SampleSchema)])
  139. bob.add_validators([SchemaValidator(SampleSchema)])
  140. assert bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
  141. assert not bob.store("experiment_name", 777, get_dht_time() + 10)
  142. assert alice.get("experiment_name", latest=True).value == b"foo_bar"
  143. alice.shutdown()
  144. bob.shutdown()