
switch to uvloop, fix docs, fix congestion, add tests for empty routing table (#55)

* add test for RPC with empty RoutingTable

* rtfd config: do not install protoc

* docs: use conda binaries for grpc (combat timeout error)

* fail on warning

* install pip via environment.yaml

* typo

* add versions

* formatting

* switch DHT to SpawnProcess

* use uvloop

* rollback to default process

* SpawnProcess

* rollback SpawnProcess

* swap lines

* trigger rtfd

* re-implement max parallel_rpc

* specify parallel_rpc in tests

* pep8
justheuristic committed 5 years ago
parent commit e9b176d600

+ 7 - 0
.readthedocs.yml

@@ -0,0 +1,7 @@
+version: 2
+
+sphinx:
+  fail_on_warning: true
+
+conda:
+  environment: docs/environment.yaml

+ 19 - 0
docs/environment.yaml

@@ -0,0 +1,19 @@
+channels:
+  - defaults
+  - anaconda
+  - pytorch
+  - conda-forge
+dependencies:
+  - grpcio
+  - grpcio-tools
+  - numpy>=1.14
+  - pytorch>=1.3.0
+  - joblib>=0.13
+  - pip
+  - pip:
+    - recommonmark
+    - sphinx_rtd_theme
+    - prefetch_generator>=1.0.1
+    - uvloop>=0.14.0
+    - umsgpack
+

+ 0 - 2
docs/requirements.txt

@@ -1,2 +0,0 @@
-recommonmark
-sphinx_rtd_theme

+ 2 - 1
hivemind/dht/__init__.py

@@ -11,6 +11,7 @@ import asyncio
 import multiprocessing as mp
 import warnings
 from typing import List, Optional
+import uvloop
 
 from .node import DHTNode, DHTID, DHTExpiration
 from .routing import get_dht_time
@@ -49,9 +50,9 @@ class DHT(mp.Process):
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         if asyncio.get_event_loop().is_running():
             asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+        uvloop.install()
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-
         self.node = loop.run_until_complete(DHTNode.create(
             initial_peers=list(self.initial_peers), listen_on=f"{LOCALHOST}:{self.port}", **self.node_params))
         run_in_background(loop.run_forever)
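
The hunk above calls uvloop.install() before building the event loop, so the loop that DHT.run() creates (and any loop created afterwards in that process) uses uvloop's libuv-based implementation instead of the default asyncio one. A minimal standalone sketch of the same pattern; the placeholder coroutine is illustrative and not part of the repo:

    import asyncio
    import uvloop

    async def serve_forever():
        # stand-in for the node's actual work
        await asyncio.sleep(3600)

    uvloop.install()                 # set uvloop as the event loop policy
    loop = asyncio.new_event_loop()  # this loop is now a uvloop loop
    asyncio.set_event_loop(loop)
    loop.run_until_complete(serve_forever())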

+ 3 - 6
hivemind/dht/node.py

@@ -47,24 +47,22 @@ class DHTNode:
     protocol: DHTProtocol
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
 
-
     @classmethod
     async def create(
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
-            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, max_requests: int = 0,
+            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
             listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
-        :param port: port to which this DHTNode will listen, by default find some open port
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
         :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
           either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
           Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
         :param num_replicas: number of nearest nodes that will be asked to store a given key, default = bucket_size (≈k)
         :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
-        :param max_requests: maximum number of outgoing RPC requests emitted by DHTProtocol in parallel
+        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
           Reduce this value if your RPC requests register no response despite the peer sending the response.
         :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
@@ -81,7 +79,6 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
-        assert max_requests == 0, "TODO(jheuristic) implement congestion!"
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas = num_replicas if num_replicas is not None else bucket_size
@@ -89,7 +86,7 @@ class DHTNode:
         self.refresh_timeout = refresh_timeout
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 cache_size, listen, listen_on, **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, **kwargs)
         self.port = self.protocol.port
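
With this change the unimplemented max_requests argument (which previously had to stay at 0) is replaced by parallel_rpc, which is forwarded to DHTProtocol.create. A hedged usage sketch of the new signature, mirroring the test change at the bottom of this diff; the bootstrap address is a placeholder:

    import asyncio
    from hivemind.dht import DHTNode

    async def start_node():
        # cap the node at 10 concurrent outgoing RPCs; parallel_rpc=None means no limit
        return await DHTNode.create(initial_peers=['127.0.0.1:1337'], parallel_rpc=10)

    node = asyncio.get_event_loop().run_until_complete(start_node())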
 
 

+ 12 - 7
hivemind/dht/protocol.py

@@ -23,13 +23,14 @@ class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
-    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable
+    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on
 
     @classmethod
-    async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
-                     cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
-                     channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
+    async def create(
+            cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
+            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
+            channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
         As a side-effect, DHTProtocol also maintains a routing table as described in
@@ -47,6 +48,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.wait_timeout, self.channel_options = wait_timeout, channel_options
         self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
+        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
 
         if listen:  # set up server to process incoming rpc requests
             grpc.experimental.aio.init_grpc_aio()
@@ -92,7 +94,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         :return: node's DHTID, if peer responded and decided to send his node_id
         """
         try:
-            peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+            async with self.rpc_semaphore:
+                peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
@@ -135,7 +138,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
                                              expiration=expirations, in_cache=in_cache, peer=self.node_info)
         try:
-            response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
+            async with self.rpc_semaphore:
+                response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
@@ -172,7 +176,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         keys = list(keys)
         find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
         try:
-            response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
+            async with self.rpc_semaphore:
+                response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
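
The congestion fix itself is the rpc_semaphore: each outgoing call (rpc_ping, rpc_store, rpc_find) now acquires the semaphore first, so at most parallel_rpc requests are in flight at once and the rest wait their turn instead of overwhelming the peer. A generic sketch of that pattern (not the repo's exact class); it reuses the same unbounded-semaphore trick as the hunk above:

    import asyncio

    class ThrottledCaller:
        def __init__(self, parallel_rpc=None):
            # unbounded semaphore when no limit is given, same as in DHTProtocol.create above
            self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

        async def call(self, rpc_coroutine_fn, *args, **kwargs):
            async with self.rpc_semaphore:  # released even if the RPC raises
                return await rpc_coroutine_fn(*args, **kwargs)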

+ 0 - 1
hivemind/runtime/__init__.py

@@ -5,7 +5,6 @@ from selectors import DefaultSelector, EVENT_READ
 from typing import Dict
 
 import torch
-import tqdm
 from prefetch_generator import BackgroundGenerator
 
 from .expert_backend import ExpertBackend

+ 1 - 2
requirements.txt

@@ -1,10 +1,9 @@
 torch>=1.3.0
 joblib>=0.13
 numpy>=1.17
-requests>=2.22.0
-tqdm
 prefetch_generator>=1.0.1
 pytest
 umsgpack
+uvloop>=0.14.0
 grpcio
 grpcio-tools>=1.30.0

+ 43 - 1
tests/test_dht.py

@@ -100,6 +100,48 @@ def test_dht_protocol():
     peer2_proc.terminate()
 
 
+def test_empty_table():
+    """ Test RPC methods with empty routing table """
+    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
+    peer_proc.start(), peer_started.wait()
+    test_success = mp.Event()
+
+    def _tester():
+        # note: we run everything in a separate process to re-initialize all global states from scratch
+        # this helps us avoid undesirable side-effects when running multiple tests in sequence
+
+        loop = asyncio.get_event_loop()
+        protocol = loop.run_until_complete(DHTProtocol.create(
+            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+
+        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+
+        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
+        assert all(loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        )), "peer rejected store"
+
+        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        assert len(nodes_found) == 0
+        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
+            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+
+        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
+        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
+        test_success.set()
+
+    tester = mp.Process(target=_tester, daemon=True)
+    tester.start()
+    tester.join()
+    assert test_success.is_set()
+    peer_proc.terminate()
+
+
 def run_node(node_id, peers, status_pipe: mp.Pipe):
     if asyncio.get_event_loop().is_running():
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
@@ -132,7 +174,7 @@ def test_dht():
         # note: we run everything in a separate process to re-initialize all global states from scratch
         # this helps us avoid undesirable side-effects when running multiple tests in sequence
         loop = asyncio.get_event_loop()
-        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5)))
+        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))
 
         # test 1: find self
         nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=me.node_id, k_nearest=1))