
Set default DHT num_workers = 4 (#342)

This change seems to speed up DHT get requests.

Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexander Borzunov, 4 years ago
commit 0b3dcf4da0

+ 5 - 5
hivemind/dht/__init__.py

@@ -23,7 +23,7 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 
 from multiaddr import Multiaddr
 
-from hivemind.dht.node import DHTNode
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
@@ -43,7 +43,7 @@ class DHT(mp.Process):
     :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param max_workers: declare_experts and get_experts will use up to this many parallel workers
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
       (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
     :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
@@ -62,7 +62,7 @@ class DHT(mp.Process):
         *,
         start: bool,
         daemon: bool = True,
-        max_workers: Optional[int] = None,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
         shutdown_timeout: float = 3,
         await_ready: bool = True,
@@ -81,7 +81,7 @@ class DHT(mp.Process):
             raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
         self.initial_peers = initial_peers
         self.kwargs = kwargs
-        self.max_workers = max_workers
+        self.num_workers = num_workers
 
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
@@ -106,7 +106,7 @@ class DHT(mp.Process):
             async def _run():
                 self._node = await DHTNode.create(
                     initial_peers=self.initial_peers,
-                    num_workers=self.max_workers or 1,
+                    num_workers=self.num_workers,
                     record_validator=self._record_validator,
                     **self.kwargs,
                 )
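
For reference, a minimal usage sketch of the renamed parameter (a hypothetical snippet assuming hivemind's top-level `DHT` export; the values are illustrative):

```python
from hivemind import DHT

# `max_workers` is renamed to `num_workers` and now defaults to
# DEFAULT_NUM_WORKERS (4) instead of None; it can still be set explicitly.
dht = DHT(start=True, num_workers=8)  # up to 8 parallel workers, at most one per key
dht.shutdown()
```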

+ 6 - 2
hivemind/dht/node.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import dataclasses
+import os
 import random
 from collections import Counter, defaultdict
 from dataclasses import dataclass, field
@@ -38,6 +39,9 @@ from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithE
 logger = get_logger(__name__)
 
 
+DEFAULT_NUM_WORKERS = int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))
+
+
 class DHTNode:
     """
     Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
@@ -110,7 +114,7 @@ class DHTNode:
         cache_refresh_before_expiry: float = 5,
         cache_on_store: bool = True,
         reuse_get_requests: bool = True,
-        num_workers: int = 1,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
@@ -154,7 +158,7 @@ class DHTNode:
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
         :param validate: if True, use initial peers to validate that this node is accessible and synchronized
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
           if True, this node will refuse any incoming requests, effectively being only a client
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
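
The new default is read from the environment once, at module import time. A small sketch of the implication (the override value 8 is just an example):

```python
import os

# Must be set before hivemind.dht.node is first imported, because
# DEFAULT_NUM_WORKERS is evaluated at module level.
os.environ["HIVEMIND_DHT_NUM_WORKERS"] = "8"

from hivemind.dht.node import DEFAULT_NUM_WORKERS

assert DEFAULT_NUM_WORKERS == 8  # holds only if the variable was set before the import
```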

+ 3 - 3
hivemind/moe/client/beam_search.py

@@ -125,7 +125,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or dht.max_workers or beam_size
+        num_workers = num_workers or dht.num_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
@@ -206,7 +206,7 @@ class MoEBeamSearcher:
         num_workers: Optional[int] = None,
     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float("inf")
-        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
         successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
         for prefix, found in dht_responses.items():
@@ -270,7 +270,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
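
These call sites resolve the worker count with the same precedence: an explicit argument wins, then the DHT-wide setting, usually capped at one worker per key. A toy restatement of that pattern (the helper name is illustrative, not part of hivemind):

```python
from typing import Optional

def resolve_num_workers(explicit: Optional[int], dht_num_workers: int, n_keys: int) -> int:
    # Explicit argument first, then the DHT setting, never more than one worker per key.
    return explicit or min(n_keys, dht_num_workers or n_keys)

assert resolve_num_workers(None, 4, 10) == 4  # the new default caps parallelism
assert resolve_num_workers(None, 4, 2) == 2   # small batches: one worker per key
assert resolve_num_workers(6, 4, 10) == 6     # an explicit value overrides the default
```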

+ 2 - 2
hivemind/moe/server/dht_handler.py

@@ -56,7 +56,7 @@ def declare_experts(
 async def _declare_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
@@ -89,7 +89,7 @@ async def _get_experts(
 ) -> List[Optional[RemoteExpert]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
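
Since the default is now a plain int of 4 rather than None, the `is None` branch above only triggers if a caller passes `num_workers=None` explicitly. A toy check of the per-key cap (the helper is illustrative, not part of hivemind):

```python
def workers_for(n_uids: int, dht_num_workers: int) -> int:
    # One worker per expert UID at most, capped by the DHT-wide setting.
    return min(n_uids, dht_num_workers)

assert workers_for(2, 4) == 2    # two experts -> two workers
assert workers_for(100, 4) == 4  # large batches stay at the new default of 4
```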