# test_dht_schema.py (8.1 KB)
  1. import asyncio
  2. from typing import Dict
  3. import pytest
  4. from pydantic import BaseModel, StrictInt, conint
  5. import hivemind
  6. from hivemind.dht.node import DHTNode
  7. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  8. from hivemind.dht.validation import DHTRecord, RecordValidatorBase
  9. from hivemind.utils.timed_storage import get_dht_time
  10. class SampleSchema(BaseModel):
  11. experiment_name: bytes
  12. n_batches: Dict[bytes, conint(ge=0, strict=True)]
  13. signed_data: Dict[BytesWithPublicKey, bytes]
  14. @pytest.fixture
  15. async def dht_nodes_with_schema():
  16. validator = SchemaValidator(SampleSchema)
  17. alice = await DHTNode.create(record_validator=validator)
  18. bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
  19. yield alice, bob
  20. await asyncio.gather(alice.shutdown(), bob.shutdown())
  21. @pytest.mark.forked
  22. @pytest.mark.asyncio
  23. async def test_expecting_regular_value(dht_nodes_with_schema):
  24. alice, bob = dht_nodes_with_schema
  25. # Regular value (bytes) expected
  26. assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
  27. assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
  28. assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
  29. subkey=b'subkey')
  30. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  31. assert not await bob.store('experiment_name', [], get_dht_time() + 10)
  32. assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)
  33. for peer in [alice, bob]:
  34. assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
  35. @pytest.mark.forked
  36. @pytest.mark.asyncio
  37. async def test_expecting_dictionary(dht_nodes_with_schema):
  38. alice, bob = dht_nodes_with_schema
  39. # Dictionary (bytes -> non-negative int) expected
  40. assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
  41. assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
  42. assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
  43. assert not await bob.store('n_batches', 666, get_dht_time() + 10)
  44. assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
  45. assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)
  46. # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
  47. assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)
  48. # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  49. assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
  50. assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
  51. assert not await bob.store('n_batches', [], get_dht_time() + 10)
  52. assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)
  53. # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
  54. assert not await bob.store('n_batches', '', get_dht_time() + 10)
  55. for peer in [alice, bob]:
  56. dictionary = (await peer.get('n_batches', latest=True)).value
  57. assert (len(dictionary) == 2 and
  58. dictionary[b'uid1'].value == 777 and
  59. dictionary[b'uid2'].value == 778)
  60. @pytest.mark.forked
  61. @pytest.mark.asyncio
  62. async def test_expecting_public_keys(dht_nodes_with_schema):
  63. alice, bob = dht_nodes_with_schema
  64. # Subkeys expected to contain a public key
  65. # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
  66. assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
  67. subkey=b'uid[owner:public-key]')
  68. assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
  69. subkey=b'uid-without-public-key')
  70. for peer in [alice, bob]:
  71. dictionary = (await peer.get('signed_data', latest=True)).value
  72. assert (len(dictionary) == 1 and
  73. dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
  74. @pytest.mark.forked
  75. @pytest.mark.asyncio
  76. async def test_keys_outside_schema(dht_nodes_with_schema):
  77. class Schema(BaseModel):
  78. some_field: StrictInt
  79. class MergedSchema(BaseModel):
  80. another_field: StrictInt
  81. for allow_extra_keys in [False, True]:
  82. validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
  83. assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
  84. alice = await DHTNode.create(record_validator=validator)
  85. bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
  86. store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
  87. assert store_ok == allow_extra_keys
  88. for peer in [alice, bob]:
  89. result = await peer.get('unknown_key', latest=True)
  90. if allow_extra_keys:
  91. assert result.value == b'foo_bar'
  92. else:
  93. assert result is None
  94. @pytest.mark.forked
  95. @pytest.mark.asyncio
  96. async def test_prefix():
  97. class Schema(BaseModel):
  98. field: StrictInt
  99. validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
  100. alice = await DHTNode.create(record_validator=validator)
  101. bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
  102. assert await bob.store('prefix_field', 777, get_dht_time() + 10)
  103. assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
  104. assert not await bob.store('field', 777, get_dht_time() + 10)
  105. for peer in [alice, bob]:
  106. assert (await peer.get('prefix_field', latest=True)).value == 777
  107. assert (await peer.get('field', latest=True)) is None
  108. await asyncio.gather(alice.shutdown(), bob.shutdown())
  109. @pytest.mark.forked
  110. @pytest.mark.asyncio
  111. async def test_merging_schema_validators(dht_nodes_with_schema):
  112. alice, bob = dht_nodes_with_schema
  113. class TrivialValidator(RecordValidatorBase):
  114. def validate(self, record: DHTRecord) -> bool:
  115. return True
  116. second_validator = TrivialValidator()
  117. # Can't merge with the validator of the different type
  118. assert not alice.protocol.record_validator.merge_with(second_validator)
  119. class SecondSchema(BaseModel):
  120. some_field: StrictInt
  121. another_field: str
  122. class ThirdSchema(BaseModel):
  123. another_field: StrictInt # Allow it to be a StrictInt as well
  124. for schema in [SecondSchema, ThirdSchema]:
  125. new_validator = SchemaValidator(schema, allow_extra_keys=False)
  126. for peer in [alice, bob]:
  127. assert peer.protocol.record_validator.merge_with(new_validator)
  128. assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
  129. assert await bob.store('some_field', 777, get_dht_time() + 10)
  130. assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)
  131. assert await bob.store('another_field', 42, get_dht_time() + 10)
  132. assert await bob.store('another_field', 'string_value', get_dht_time() + 10)
  133. # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
  134. assert await bob.store('unknown_key', 999, get_dht_time() + 10)
  135. for peer in [alice, bob]:
  136. assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
  137. assert (await peer.get('some_field', latest=True)).value == 777
  138. assert (await peer.get('another_field', latest=True)).value == 'string_value'
  139. assert (await peer.get('unknown_key', latest=True)).value == 999
  140. @pytest.mark.forked
  141. def test_sending_validator_instance_between_processes():
  142. alice = hivemind.DHT(start=True)
  143. bob = hivemind.DHT(start=True, initial_peers=alice.get_visible_maddrs())
  144. alice.add_validators([SchemaValidator(SampleSchema)])
  145. bob.add_validators([SchemaValidator(SampleSchema)])
  146. assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
  147. assert not bob.store('experiment_name', 777, get_dht_time() + 10)
  148. assert alice.get('experiment_name', latest=True).value == b'foo_bar'
  149. alice.shutdown()
  150. bob.shutdown()