test_dht_schema.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import re
  2. import pytest
  3. from pydantic import BaseModel, StrictFloat, StrictInt, conint
  4. from typing import Dict
  5. import hivemind
  6. from hivemind.dht import get_dht_time
  7. from hivemind.dht.node import DHTNode, LOCALHOST
  8. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator, conbytes
  9. from hivemind.dht.validation import DHTRecord, RecordValidatorBase
  10. class SampleSchema(BaseModel):
  11. experiment_name: bytes
  12. n_batches: Dict[bytes, conint(ge=0, strict=True)]
  13. signed_data: Dict[BytesWithPublicKey, bytes]
  14. @pytest.fixture
  15. async def dht_nodes_with_schema():
  16. validator = SchemaValidator(SampleSchema)
  17. alice = await DHTNode.create(record_validator=validator)
  18. bob = await DHTNode.create(
  19. record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
  20. return alice, bob
  21. @pytest.mark.forked
  22. @pytest.mark.asyncio
  23. async def test_expecting_regular_value(dht_nodes_with_schema):
  24. alice, bob = dht_nodes_with_schema
  25. # Regular value (bytes) expected
  26. assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
  27. assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
  28. assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
  29. subkey=b'subkey')
  30. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  31. assert not await bob.store('experiment_name', [], get_dht_time() + 10)
  32. assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)
  33. for peer in [alice, bob]:
  34. assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
  35. @pytest.mark.forked
  36. @pytest.mark.asyncio
  37. async def test_expecting_dictionary(dht_nodes_with_schema):
  38. alice, bob = dht_nodes_with_schema
  39. # Dictionary (bytes -> non-negative int) expected
  40. assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
  41. assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
  42. assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
  43. assert not await bob.store('n_batches', 666, get_dht_time() + 10)
  44. assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
  45. assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)
  46. # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
  47. assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)
  48. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  49. assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
  50. assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
  51. assert not await bob.store('n_batches', [], get_dht_time() + 10)
  52. assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)
  53. # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
  54. assert not await bob.store('n_batches', '', get_dht_time() + 10)
  55. for peer in [alice, bob]:
  56. dictionary = (await peer.get('n_batches', latest=True)).value
  57. assert (len(dictionary) == 2 and
  58. dictionary[b'uid1'].value == 777 and
  59. dictionary[b'uid2'].value == 778)
  60. @pytest.mark.forked
  61. @pytest.mark.asyncio
  62. async def test_expecting_public_keys(dht_nodes_with_schema):
  63. alice, bob = dht_nodes_with_schema
  64. # Subkeys expected to contain a public key
  65. # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
  66. assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
  67. subkey=b'uid[owner:public-key]')
  68. assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
  69. subkey=b'uid-without-public-key')
  70. for peer in [alice, bob]:
  71. dictionary = (await peer.get('signed_data', latest=True)).value
  72. assert (len(dictionary) == 1 and
  73. dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
  74. @pytest.mark.forked
  75. @pytest.mark.asyncio
  76. async def test_keys_outside_schema(dht_nodes_with_schema):
  77. class Schema(BaseModel):
  78. some_field: StrictInt
  79. class MergedSchema(BaseModel):
  80. another_field: StrictInt
  81. for allow_extra_keys in [False, True]:
  82. validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
  83. assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
  84. alice = await DHTNode.create(record_validator=validator)
  85. bob = await DHTNode.create(
  86. record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
  87. store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
  88. assert store_ok == allow_extra_keys
  89. for peer in [alice, bob]:
  90. result = await peer.get('unknown_key', latest=True)
  91. if allow_extra_keys:
  92. assert result.value == b'foo_bar'
  93. else:
  94. assert result is None
  95. @pytest.mark.forked
  96. @pytest.mark.asyncio
  97. async def test_prefix():
  98. class Schema(BaseModel):
  99. field: StrictInt
  100. validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
  101. alice = await DHTNode.create(record_validator=validator)
  102. bob = await DHTNode.create(
  103. record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
  104. assert await bob.store('prefix_field', 777, get_dht_time() + 10)
  105. assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
  106. assert not await bob.store('field', 777, get_dht_time() + 10)
  107. for peer in [alice, bob]:
  108. assert (await peer.get('prefix_field', latest=True)).value == 777
  109. assert (await peer.get('field', latest=True)) is None
  110. @pytest.mark.forked
  111. @pytest.mark.asyncio
  112. async def test_merging_schema_validators(dht_nodes_with_schema):
  113. alice, bob = dht_nodes_with_schema
  114. class TrivialValidator(RecordValidatorBase):
  115. def validate(self, record: DHTRecord) -> bool:
  116. return True
  117. second_validator = TrivialValidator()
  118. # Can't merge with the validator of the different type
  119. assert not alice.protocol.record_validator.merge_with(second_validator)
  120. class SecondSchema(BaseModel):
  121. some_field: StrictInt
  122. another_field: str
  123. class ThirdSchema(BaseModel):
  124. another_field: StrictInt # Allow it to be a StrictInt as well
  125. for schema in [SecondSchema, ThirdSchema]:
  126. new_validator = SchemaValidator(schema, allow_extra_keys=False)
  127. for peer in [alice, bob]:
  128. assert peer.protocol.record_validator.merge_with(new_validator)
  129. assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
  130. assert await bob.store('some_field', 777, get_dht_time() + 10)
  131. assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)
  132. assert await bob.store('another_field', 42, get_dht_time() + 10)
  133. assert await bob.store('another_field', 'string_value', get_dht_time() + 10)
  134. # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
  135. assert await bob.store('unknown_key', 999, get_dht_time() + 10)
  136. for peer in [alice, bob]:
  137. assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
  138. assert (await peer.get('some_field', latest=True)).value == 777
  139. assert (await peer.get('another_field', latest=True)).value == 'string_value'
  140. assert (await peer.get('unknown_key', latest=True)).value == 999
  141. @pytest.mark.forked
  142. def test_sending_validator_instance_between_processes():
  143. alice = hivemind.DHT(start=True)
  144. bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])
  145. alice.add_validators([SchemaValidator(SampleSchema)])
  146. bob.add_validators([SchemaValidator(SampleSchema)])
  147. assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
  148. assert not bob.store('experiment_name', 777, get_dht_time() + 10)
  149. assert alice.get('experiment_name', latest=True).value == b'foo_bar'