
switch to uvloop, fix docs, fix congestion, add tests for empty routing table (#55)

* add test for RPC with empty RoutingTable

* rtfd config: do not install protoc

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* docs: use conda binaries for grpc (combat timeout error)

* fail on warning

* install pip via environment.yaml

* typo

* add versions

* formatting

* switch DHT to SpawnProcess

* use uvloop

* rollback to default process

* rollback to default process

* rollback to default process

* rollback to default process

* rollback to default process

* rollback to default process

* SpawnProcess

* rollback SpawnProcess

* swap lines

* trigger rtfd

* re-implement max parallel_rpc

* re-implement max parallel_rpc

* specify parallel_rpc in tests

* pep8
justheuristic · commit e9b176d600

+ 7 - 0
.readthedocs.yml

@@ -0,0 +1,7 @@
+version: 2
+
+sphinx:
+  fail_on_warning: true
+
+conda:
+  environment: docs/environment.yaml
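
Note: with this `.readthedocs.yml` (config version 2), Read the Docs builds the docs from the conda environment in `docs/environment.yaml` instead of pip-installing requirements, and `sphinx: fail_on_warning: true` turns any Sphinx warning into a build failure. A roughly equivalent local check, assuming the Sphinx sources live under `docs/`, is `sphinx-build -W docs/ <build_dir>` (the `-W` flag likewise promotes warnings to errors).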

+ 19 - 0
docs/environment.yaml

@@ -0,0 +1,19 @@
+channels:
+  - defaults
+  - anaconda
+  - pytorch
+  - conda-forge
+dependencies:
+  - grpcio
+  - grpcio-tools
+  - numpy>=1.14
+  - pytorch>=1.3.0
+  - joblib>=0.13
+  - pip
+  - pip:
+    - recommonmark
+    - sphinx_rtd_theme
+    - prefetch_generator>=1.0.1
+    - uvloop>=0.14.0
+    - umsgpack
+
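
Note: pulling `grpcio` / `grpcio-tools` as conda binaries avoids compiling gRPC from source during the docs build, which is what was hitting the Read the Docs timeout mentioned in the commit message. To reproduce this environment locally (assuming conda is available), run `conda env create -f docs/environment.yaml`.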

+ 0 - 2
docs/requirements.txt

@@ -1,2 +0,0 @@
-recommonmark
-sphinx_rtd_theme

+ 2 - 1
hivemind/dht/__init__.py

@@ -11,6 +11,7 @@ import asyncio
 import multiprocessing as mp
 import warnings
 from typing import List, Optional
+import uvloop
 
 from .node import DHTNode, DHTID, DHTExpiration
 from .routing import get_dht_time
@@ -49,9 +50,9 @@ class DHT(mp.Process):
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         if asyncio.get_event_loop().is_running():
             asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+        uvloop.install()
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-
         self.node = loop.run_until_complete(DHTNode.create(
             initial_peers=list(self.initial_peers), listen_on=f"{LOCALHOST}:{self.port}", **self.node_params))
         run_in_background(loop.run_forever)
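
For readers unfamiliar with uvloop: `uvloop.install()` replaces asyncio's default event loop policy with uvloop's libuv-based one, so every loop created afterwards (including the one `DHT.run` keeps alive via `run_in_background(loop.run_forever)`) is a uvloop loop. A minimal standalone sketch of that pattern, with a placeholder coroutine instead of `DHTNode.create` (the coroutine and the printout are illustrative, not hivemind code):

```python
import asyncio
import uvloop

async def example_task() -> str:
    # placeholder coroutine standing in for DHTNode.create(...)
    await asyncio.sleep(0.1)
    return "done"

uvloop.install()                 # all event loops created from now on use uvloop
loop = asyncio.new_event_loop()  # same pattern as DHT.run: fresh loop after uvloop.install()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(example_task())
print(type(loop), result)        # expected: <class 'uvloop.Loop'> done
```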

+ 3 - 6
hivemind/dht/node.py

@@ -47,24 +47,22 @@ class DHTNode:
     protocol: DHTProtocol
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
 
-
     @classmethod
     async def create(
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
-            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, max_requests: int = 0,
+            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
             listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
-        :param port: port to which this DHTNode will listen, by default find some open port
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
         :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
           either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
           Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
         :param num_replicas: number of nearest nodes that will be asked to store a given key, default = bucket_size (≈k)
         :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
-        :param max_requests: maximum number of outgoing RPC requests emitted by DHTProtocol in parallel
+        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
           Reduce this value if your RPC requests register no response despite the peer sending the response.
         :param wait_timeout: a kademlia rpc request is deemed lost if we did not recieve a reply in this many seconds
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
@@ -81,7 +79,6 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
-        assert max_requests == 0, "TODO(jheuristic) implement congestion!"
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas = num_replicas if num_replicas is not None else bucket_size
@@ -89,7 +86,7 @@ class DHTNode:
         self.refresh_timeout = refresh_timeout
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 cache_size, listen, listen_on, **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, **kwargs)
         self.port = self.protocol.port
 
 

+ 12 - 7
hivemind/dht/protocol.py

@@ -23,13 +23,14 @@ class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
-    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable
+    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on
 
     @classmethod
-    async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
-                     cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
-                     channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
+    async def create(
+            cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
+            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
+            channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
         As a side-effect, DHTProtocol also maintains a routing table as described in
@@ -47,6 +48,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.wait_timeout, self.channel_options = wait_timeout, channel_options
         self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
+        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
 
         if listen:  # set up server to process incoming rpc requests
             grpc.experimental.aio.init_grpc_aio()
@@ -92,7 +94,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         :return: node's DHTID, if peer responded and decided to send his node_id
         """
         try:
-            peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+            async with self.rpc_semaphore:
+                peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
@@ -135,7 +138,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
                                              expiration=expirations, in_cache=in_cache, peer=self.node_info)
         try:
-            response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
+            async with self.rpc_semaphore:
+                response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
@@ -172,7 +176,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         keys = list(keys)
         find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
         try:
-            response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
+            async with self.rpc_semaphore:
+                response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
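
This is the congestion fix from the commit message: every outgoing `rpc_ping` / `rpc_store` / `rpc_find` now acquires a shared `asyncio.Semaphore`, so at most `parallel_rpc` requests are in flight at once (`DHTNode.create` simply forwards the value to `DHTProtocol.create`, as the updated test does with `parallel_rpc=10`). Below is a self-contained sketch of the same bounded-concurrency pattern, with a dummy coroutine standing in for the gRPC stubs; `RateLimitedClient` and its method names are illustrative, not hivemind code:

```python
import asyncio

class RateLimitedClient:
    """Toy stand-in for DHTProtocol: at most `parallel_rpc` concurrent outgoing calls."""

    def __init__(self, parallel_rpc: int = None):
        # same trick as DHTProtocol.create: no limit when parallel_rpc is None
        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

    async def call(self, i: int) -> int:
        async with self.rpc_semaphore:   # wait here if parallel_rpc calls are already in flight
            await asyncio.sleep(0.1)     # pretend to be a network round-trip
            return i

async def main():
    client = RateLimitedClient(parallel_rpc=10)
    # 100 requests are issued, but never more than 10 run concurrently
    results = await asyncio.gather(*(client.call(i) for i in range(100)))
    assert results == list(range(100))

asyncio.run(main())
```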

+ 0 - 1
hivemind/runtime/__init__.py

@@ -5,7 +5,6 @@ from selectors import DefaultSelector, EVENT_READ
 from typing import Dict
 
 import torch
-import tqdm
 from prefetch_generator import BackgroundGenerator
 
 from .expert_backend import ExpertBackend

+ 1 - 2
requirements.txt

@@ -1,10 +1,9 @@
 torch>=1.3.0
 joblib>=0.13
 numpy>=1.17
-requests>=2.22.0
-tqdm
 prefetch_generator>=1.0.1
 pytest
 umsgpack
+uvloop>=0.14.0
 grpcio
 grpcio-tools>=1.30.0

+ 43 - 1
tests/test_dht.py

@@ -100,6 +100,48 @@ def test_dht_protocol():
     peer2_proc.terminate()
 
 
+def test_empty_table():
+    """ Test RPC methods with empty routing table """
+    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
+    peer_proc.start(), peer_started.wait()
+    test_success = mp.Event()
+
+    def _tester():
+        # note: we run everything in a separate process to re-initialize all global states from scratch
+        # this helps us avoid undesirable side-effects when running multiple tests in sequence
+
+        loop = asyncio.get_event_loop()
+        protocol = loop.run_until_complete(DHTProtocol.create(
+            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+
+        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+
+        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
+        assert all(loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        )), "peer rejected store"
+
+        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        assert len(nodes_found) == 0
+        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
+            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+
+        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
+        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
+        test_success.set()
+
+    tester = mp.Process(target=_tester, daemon=True)
+    tester.start()
+    tester.join()
+    assert test_success.is_set()
+    peer_proc.terminate()
+
+
 def run_node(node_id, peers, status_pipe: mp.Pipe):
     if asyncio.get_event_loop().is_running():
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
@@ -132,7 +174,7 @@ def test_dht():
         # note: we run everything in a separate process to re-initialize all global states from scratch
         # this helps us avoid undesirable side-effects when running multiple tests in sequence
         loop = asyncio.get_event_loop()
-        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5)))
+        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))
 
         # test 1: find self
         nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=me.node_id, k_nearest=1))
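
To run only the new test locally (assuming pytest is installed and the repository root is on the path), use `pytest tests/test_dht.py -k empty_table`. Like the existing tests, it spawns helper processes and reports success through a multiprocessing Event, so a failure inside the child process surfaces as an unset `test_success` flag rather than a traceback in the parent.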