
Set default DHT num_workers = 4 (#342)

This change seems to speed up DHT get requests.

Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Author: Alexander Borzunov (4 years ago)
Commit: 0b3dcf4da0
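
For context, a minimal usage sketch of the new default (not part of the diff; it assumes the hivemind package at this commit is importable and that a DHT process can be started locally):

    import hivemind

    # num_workers now defaults to DEFAULT_NUM_WORKERS (4, or the value of the
    # HIVEMIND_DHT_NUM_WORKERS environment variable), so most callers can omit it:
    dht = hivemind.DHT(start=True)

    # callers that previously pinned max_workers should switch to the num_workers keyword:
    dht_single = hivemind.DHT(start=True, num_workers=1)

    dht.shutdown()
    dht_single.shutdown()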

hivemind/dht/__init__.py (+5, -5)

@@ -23,7 +23,7 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 
 from multiaddr import Multiaddr
 
-from hivemind.dht.node import DHTNode
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
@@ -43,7 +43,7 @@ class DHT(mp.Process):
     :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param max_workers: declare_experts and get_experts will use up to this many parallel workers
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
       (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
     :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
@@ -62,7 +62,7 @@ class DHT(mp.Process):
         *,
         start: bool,
         daemon: bool = True,
-        max_workers: Optional[int] = None,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
         shutdown_timeout: float = 3,
         await_ready: bool = True,
@@ -81,7 +81,7 @@ class DHT(mp.Process):
             raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
         self.initial_peers = initial_peers
         self.kwargs = kwargs
-        self.max_workers = max_workers
+        self.num_workers = num_workers
 
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
@@ -106,7 +106,7 @@ class DHT(mp.Process):
             async def _run():
                 self._node = await DHTNode.create(
                     initial_peers=self.initial_peers,
-                    num_workers=self.max_workers or 1,
+                    num_workers=self.num_workers,
                     record_validator=self._record_validator,
                     **self.kwargs,
                 )
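
The effective behavior change in this file: with the old code, leaving max_workers unset meant DHTNode.create received num_workers=1, and an explicit max_workers=0 silently fell back to 1 as well; the new code forwards the DHT-level default (4) unchanged. A toy illustration of the old fallback expression (standalone Python, names are local to the example):

    # old forwarding logic: num_workers = self.max_workers or 1
    for max_workers in (None, 0, 4):
        print(max_workers, "->", max_workers or 1)  # None -> 1, 0 -> 1, 4 -> 4

    # new forwarding logic passes self.num_workers (default 4) through unchanged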

hivemind/dht/node.py (+6, -2)

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import dataclasses
+import os
 import random
 from collections import Counter, defaultdict
 from dataclasses import dataclass, field
@@ -38,6 +39,9 @@ from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithE
 logger = get_logger(__name__)
 
 
+DEFAULT_NUM_WORKERS = int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))
+
+
 class DHTNode:
     """
     Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
@@ -110,7 +114,7 @@ class DHTNode:
         cache_refresh_before_expiry: float = 5,
         cache_on_store: bool = True,
         reuse_get_requests: bool = True,
-        num_workers: int = 1,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
@@ -154,7 +158,7 @@ class DHTNode:
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
         :param validate: if True, use initial peers to validate that this node is accessible and synchronized
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
          if True, this node will refuse any incoming requests, effectively being only a client
        :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
        :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
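
Since DEFAULT_NUM_WORKERS is computed with os.getenv at import time, the environment variable only takes effect if it is set before hivemind.dht.node is imported. A short sketch of both knobs (the override value 8 is arbitrary; the check against node.num_workers assumes DHTNode stores the setting on the instance):

    import asyncio
    import os

    os.environ["HIVEMIND_DHT_NUM_WORKERS"] = "8"  # must be set before the import below
    from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode

    print(DEFAULT_NUM_WORKERS)  # 8 here; 4 when the variable is unset

    async def main():
        node = await DHTNode.create()  # picks up the new default
        assert node.num_workers == DEFAULT_NUM_WORKERS
        await node.shutdown()

    asyncio.run(main())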

hivemind/moe/client/beam_search.py (+3, -3)

@@ -125,7 +125,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or dht.max_workers or beam_size
+        num_workers = num_workers or dht.num_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
@@ -206,7 +206,7 @@ class MoEBeamSearcher:
         num_workers: Optional[int] = None,
     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float("inf")
-        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
         successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
         for prefix, found in dht_responses.items():
@@ -270,7 +270,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
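
A toy check of the min-capped fallback used in the second and third hunks above (the first hunk omits the cap). The helper name resolve_num_workers is hypothetical; only the expression itself comes from the code. Note that after this commit dht.num_workers is always an int, so the trailing "or beam_size" branch is mostly a leftover from the Optional max_workers era:

    def resolve_num_workers(num_workers, dht_num_workers, beam_size):
        # mirrors: num_workers or min(beam_size, dht.num_workers or beam_size)
        return num_workers or min(beam_size, dht_num_workers or beam_size)

    assert resolve_num_workers(None, 4, beam_size=16) == 4  # new default: at most 4 parallel workers
    assert resolve_num_workers(2, 4, beam_size=16) == 2     # an explicit argument still wins
    assert resolve_num_workers(None, 4, beam_size=3) == 3   # never more workers than beam entries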

hivemind/moe/server/dht_handler.py (+2, -2)

@@ -56,7 +56,7 @@ def declare_experts(
 async def _declare_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
@@ -89,7 +89,7 @@ async def _get_experts(
 ) -> List[Optional[RemoteExpert]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
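
And a toy check of the per-call rule in _declare_experts / _get_experts (the helper name workers_for is hypothetical; the conditional mirrors the changed lines):

    def workers_for(num_uids, dht_num_workers):
        # mirrors: len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
        return num_uids if dht_num_workers is None else min(num_uids, dht_num_workers)

    assert workers_for(100, 4) == 4       # the new default caps parallelism at 4
    assert workers_for(2, 4) == 2         # still never more than one worker per key
    assert workers_for(100, None) == 100  # only reachable if num_workers is explicitly set to None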