瀏覽代碼

Support edge cases for DHT key/subkey/value, add tests, update .gitignore for pb2 (#167)

* fix bug with subkey equals zero
* add autogenerated protobuf files to .gitignore
* test store and get "tricky" values in dht
Michael Diskin 4 年之前
父節點
當前提交
03aca7c479
共有 4 個文件被更改,包括 28 次插入3 次删除
  1. 3 0
      .gitignore
  2. 1 1
      hivemind/__init__.py
  3. 2 2
      hivemind/dht/node.py
  4. 22 0
      tests/test_dht_node.py

+ 3 - 0
.gitignore

@@ -75,3 +75,6 @@ debian/reproducible-experiment-platform
 debian/files
 *.substvars
 *.debhelper.log
+
+# protobuf stuff
+hivemind/proto/*_pb2*

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.9.2'
+__version__ = '0.9.3'

+ 2 - 2
hivemind/dht/node.py

@@ -338,9 +338,9 @@ class DHTNode:
             queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
         try:
-            await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
+            await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
             assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return {(key, subkey) if subkey else key: status or False for (key, subkey), status in store_ok.items()}
+            return {(key, subkey) if subkey is not None else key: status or False for (key, subkey), status in store_ok.items()}
         except asyncio.CancelledError as e:
             store_task.cancel()
             raise e

+ 22 - 0
tests/test_dht_node.py

@@ -5,6 +5,7 @@ import heapq
 from typing import Optional
 import numpy as np
 import pytest
+from itertools import product
 
 import hivemind
 from typing import List, Dict
@@ -433,3 +434,24 @@ async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
     with pytest.raises(ValidationError):
         node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
                                               endpoint=fake_endpoint)
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_edge_cases():
+    peers = []
+    for i in range(5):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+
+    subkeys = [0, '', False, True, 'abyrvalg', 4555]
+    keys = subkeys + [()]
+    values = subkeys + [[]]
+    for key, subkey, value in product(keys, subkeys, values):
+        await random.choice(peers).store(key=key, subkey=subkey, value=value,
+                                         expiration_time=hivemind.get_dht_time() + 999),
+
+        stored = await random.choice(peers).get(key=key, latest=True)
+        assert stored is not None
+        assert subkey in stored.value
+        assert stored.value[subkey].value == value