
Update quickstart.md - add note about faster beam search (#113)

* Update quickstart.md

* Update quickstart.md

* even more conservative caching: do not cache if we don't know value yet
justheuristic committed 4 years ago
parent
commit
bc8ce59fd6
2 files changed with 11 additions and 9 deletions
  1. docs/user/quickstart.md (+3 -1)
  2. hivemind/dht/node.py (+8 -8)

+ 3 - 1
docs/user/quickstart.md

@@ -16,6 +16,8 @@ python setup.py install
 
 You can also install it in editable mode with `python setup.py develop`.
 
+__Note:__ we currently recommend installing hivemind from GitHub (i.e. not from pip), as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in the v0.9.0 release.
+
 * __Dependencies:__ Hivemind requires Python 3.7+ (3.8 is recommended); it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
 * __OS support:__ Linux and Mac OS should [just work](https://github.com/learning-at-home/hivemind/issues).
 We do not officially support Windows, but you are welcome to try and contribute your Windows build :)
@@ -197,4 +199,4 @@ You can find more details on how MoE works in Section 2.3 of the [paper](https:/
 
 Congratulations, you've made it through the basic tutorial. Give yourself a pat on the back :)
 
-More advanced tutorials are coming soon :)
+More advanced tutorials are coming soon :)

+ 8 - 8
hivemind/dht/node.py

@@ -65,7 +65,7 @@ class DHTNode:
     """
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
-    refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
+    chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on
@@ -76,7 +76,7 @@ class DHTNode:
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1,
+            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
             listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
@@ -102,6 +102,7 @@ class DHTNode:
         :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
           all concurrent get requests for the same key will reuse the procedure that is currently in progress
         :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
+        :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
         :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
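
For context, a minimal usage sketch of the new parameter, assuming (as the `-> DHTNode` return annotation above suggests) that `create` is an async classmethod and that a standalone node can bootstrap with all-default arguments:

```python
import asyncio

from hivemind.dht.node import DHTNode

async def main() -> None:
    # chunk_size (introduced in this commit, default 16) caps how many
    # concurrent calls get_many and the cache refresh loop may issue at once.
    node = await DHTNode.create(chunk_size=8)
    print(node.chunk_size)  # -> 8

asyncio.run(main())
```
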
@@ -111,7 +112,7 @@ class DHTNode:
         """
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
-        self.num_replicas, self.num_workers = num_replicas, num_workers
+        self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
 
         self.reuse_get_requests = reuse_get_requests
@@ -340,10 +341,9 @@ class DHTNode:
             self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
         elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
             rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
-            self.protocol.cache.store(key_id, rejected_value, rejected_expiration)  # can still be better than cache
             if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
                 self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
-        else:  # stored a dictionary (or failed to store), either way, there can be other keys and we should update
+        elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
             for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
                 self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
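
To make the rewritten branch easier to follow, here is a self-contained sketch of the refresh decision with illustrative names (not the actual hivemind API): on a rejected store, the cache is no longer overwritten; a background refresh is scheduled only if the cached entry expires no later than the newest rejected value.

```python
from typing import Optional

def should_schedule_refresh(cached_expiration: Optional[float],
                            rejected_expiration: float) -> bool:
    # Mirrors `(cache.get(key_id)[1] or float("inf")) <= rejected_expiration`:
    # if nothing is cached, the left side becomes +inf and no refresh is scheduled.
    return (cached_expiration or float("inf")) <= rejected_expiration

assert should_schedule_refresh(None, 100.0) is False   # nothing cached -> skip
assert should_schedule_refresh(90.0, 100.0) is True    # cache is staler -> refresh
assert should_schedule_refresh(110.0, 100.0) is False  # cache outlives rejection
```
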
@@ -452,8 +452,8 @@ class DHTNode:
             self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)
 
         asyncio.create_task(traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
-            beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
+            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), beam_size=beam_size,
+            num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
             get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
             found_callback=found_callback, await_all_tasks=False))
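
The new `queries_per_call` value is easy to see in isolation: square-root batching as before, now clipped at `chunk_size`. A small sketch (parameter names are illustrative):

```python
def queries_per_call(num_unfinished_keys: int, chunk_size: int = 16) -> int:
    # sqrt batching for small workloads, capped at chunk_size for large ones
    return min(int(num_unfinished_keys ** 0.5), chunk_size)

assert queries_per_call(9) == 3       # small workload: sqrt batching
assert queries_per_call(1024) == 16   # large workload: capped at chunk_size
```
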
 
@@ -516,7 +516,7 @@ class DHTNode:
                 keys_to_refresh = {key_id}
                 max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
                 del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                while self.cache_refresh_queue:
+                while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
                     key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                     if nearest_refresh_time > current_time:
                         break
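
Finally, a simplified standalone version of the bounded refresh batch above, assuming a priority-queue-like object with the `top()` and `del` operations used in the hunk (illustrative, not the real CacheRefreshQueue):

```python
def take_refresh_batch(queue, first_key_id, current_time, chunk_size=16):
    """Collect keys that are due for refresh, at most chunk_size per batch."""
    keys_to_refresh = {first_key_id}
    while queue and len(keys_to_refresh) < chunk_size:
        key_id, (_, nearest_refresh_time) = queue.top()
        if nearest_refresh_time > current_time:
            break  # the earliest remaining key is not due yet; stop collecting
        del queue[key_id]
        keys_to_refresh.add(key_id)
    return keys_to_refresh
```
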