
Update quickstart.md - add note about faster beam search (#113)

* Update quickstart.md

* Update quickstart.md

* even more conservative caching: do not cache if we don't know value yet
justheuristic 4 years ago
parent commit bc8ce59fd6
2 changed files with 11 additions and 9 deletions
  1. +3 −1  docs/user/quickstart.md
  2. +8 −8  hivemind/dht/node.py

+ 3 - 1
docs/user/quickstart.md

@@ -16,6 +16,8 @@ python setup.py install
 
 You can also install it in editable mode with `python setup.py develop`.
 
+__Note:__ we currently recommend installing hivemind from GitHub (i.e. not pip), as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in the v0.9.0 release.
+
 * __Dependencies:__ Hivemind requires python 3.7+ (3.8 is recommended), it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
 * __OS support:__ Linux and Mac OS should [just work](https://github.com/learning-at-home/hivemind/issues).
 We do not officially support Windows, but you are welcome to try and contribute your windows build :)
@@ -197,4 +199,4 @@ You can find more details on how MoE works in Section 2.3 of the [paper](https:/
 
 Congratulations, you've made it through the basic tutorial. Give yourself a pat on the back :)
 
-More advanced tutorials are coming soon :)
+More advanced tutorials are coming soon :)
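
For reference, the GitHub install recommended in the note above follows the same steps the quickstart already shows (a sketch; the clone URL comes from the requirements link above):

```sh
git clone https://github.com/learning-at-home/hivemind.git
cd hivemind
python setup.py install  # or `python setup.py develop` for editable mode
```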

+ 8 - 8
hivemind/dht/node.py

@@ -65,7 +65,7 @@ class DHTNode:
     """
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
-    refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
+    chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on
@@ -76,7 +76,7 @@ class DHTNode:
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1,
+            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
             listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
@@ -102,6 +102,7 @@ class DHTNode:
         :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
           all concurrent get requests for the same key will reuse the procedure that is currently in progress
         :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
+        :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
         :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
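
A minimal usage sketch of the new parameter (assuming `DHTNode.create` is awaited from an asyncio event loop, consistent with the `asyncio` usage elsewhere in this file; the endpoint and values are illustrative):

```python
import asyncio

from hivemind.dht.node import DHTNode

async def main():
    # chunk_size (added in this commit) caps how many concurrent calls
    # get_many and the cache refresh loop may issue at once; 16 is the default
    node = await DHTNode.create(listen_on="0.0.0.0:*", chunk_size=16)
    node.is_alive = False  # setting this cancels background jobs (see the attribute comment above)

asyncio.run(main())
```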
@@ -111,7 +112,7 @@ class DHTNode:
         """
         """
         self = cls(_initialized_with_create=True)
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
-        self.num_replicas, self.num_workers = num_replicas, num_workers
+        self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
 
         self.reuse_get_requests = reuse_get_requests
@@ -340,10 +341,9 @@ class DHTNode:
             self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
         elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
             rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
-            self.protocol.cache.store(key_id, rejected_value, rejected_expiration)  # can still be better than cache
             if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
                 self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
-        else:  # stored a dictionary (or failed to store), either way, there can be other keys and we should update
+        elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
             for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
                 self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
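
A toy restatement of the "more conservative caching" change in this hunk (hypothetical helper name; the real logic lives in the method above): a rejected store no longer overwrites the local cache, it only triggers a background refresh when the cached entry is at least as stale as the value the network rejected:

```python
def should_refresh_after_rejected_store(cached_expiration, rejected_expiration):
    """Return True if the key should be re-fetched in the background.

    cached_expiration is None when the key is absent from the local cache;
    `or float("inf")` then makes the comparison False, so unknown values are
    neither cached nor refreshed. Mirrors the check above:
    (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration
    """
    return (cached_expiration or float("inf")) <= rejected_expiration
```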
@@ -452,8 +452,8 @@ class DHTNode:
             self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)
 
         asyncio.create_task(traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
-            beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
+            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), beam_size=beam_size,
+            num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
             get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
             found_callback=found_callback, await_all_tasks=False))
 
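
The new `queries_per_call` expression simply caps the old square-root heuristic at `chunk_size`, so a single RPC never bundles too many keys. A quick illustration (key counts are arbitrary examples):

```python
chunk_size = 16  # default from the signature above

for num_keys in (4, 100, 1024):
    print(num_keys, min(int(num_keys ** 0.5), chunk_size))
# 4 -> 2, 100 -> 10, 1024 -> 16 (sqrt would give 32, chunk_size wins)
```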
@@ -516,7 +516,7 @@ class DHTNode:
                 keys_to_refresh = {key_id}
                 max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
                 del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                while self.cache_refresh_queue:
+                while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
                     key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                     if nearest_refresh_time > current_time:
                         break
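
A self-contained model of the amended loop, with a plain heapq standing in for the CacheRefreshQueue (hypothetical, simplified): at most `chunk_size` due keys are drained per batch, so a long backlog can no longer be refreshed in a single burst:

```python
import heapq

def take_refresh_batch(queue, chunk_size, current_time):
    """Pop up to chunk_size keys whose scheduled refresh time has arrived.

    queue is a heap of (refresh_time, key_id) pairs; keys scheduled for the
    future stay put, like the `nearest_refresh_time > current_time` check above.
    """
    batch = []
    while queue and len(batch) < chunk_size:
        refresh_time, key_id = queue[0]
        if refresh_time > current_time:
            break  # earliest remaining key is not due yet
        heapq.heappop(queue)
        batch.append(key_id)
    return batch

# usage: refresh everything due now, at most 16 keys at a time
pending = [(0.0, "keyA"), (1.0, "keyB"), (9e9, "keyC")]
heapq.heapify(pending)
print(take_refresh_batch(pending, chunk_size=16, current_time=5.0))  # ['keyA', 'keyB']
```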