Эх сурвалжийг харах

Reuse only successful DHT search results (#130)

* Reuse only successful DHT search results

* Fix typos in docs

* Remove unnecessary assertion message
(pytest displays unequal values on failure anyway)

* SortedList -> SortedSet, delete by key
Max Ryabinin 4 жил өмнө
parent
commit
0ed7b46bb6

+ 2 - 2
docs/user/acknowledgements.md

@@ -18,9 +18,9 @@ We kindly thank (in random order)
 
 We also want to reference several projects that have similar ideas in mind:
 
-* [BitTensor](https://github.com/opentensor/BitTensor) - a decentralized deep learning ecosystem with with incentive
+* [BitTensor](https://github.com/opentensor/BitTensor) - a decentralized deep learning ecosystem with incentive
  mechanism. Like hivemind, but peers are getting rewarded for their contribution to other peers.
   _(note: as of 26.08.2020 the project is in the early stages of development)_.
-* [GShard](https://arxiv.org/abs/2006.16668) - a paper by Dmitry Lepikhin et al that demonstrate the effectiveness
+* [GShard](https://arxiv.org/abs/2006.16668) - a paper by Dmitry Lepikhin et al. that demonstrates the effectiveness
   of huge Mixture-of-Experts models on conventional hpc hardware. Those guys train models 4 times the size of GPT-3 on thousands of TPUv3.
 * Also doing research in decentralized deep learning? Let us know!

+ 5 - 5
docs/user/quickstart.md

@@ -11,16 +11,16 @@ You can also install the bleeding edge version from github:
 ```
 git clone https://github.com/learning-at-home/hivemind
 cd hivemind
-python setup.py install
+pip install .
 ```
 
-You can also install it in editable mode with `python setup.py develop`.
+You can also install it in editable mode with `pip install -e .`.
 
 __Note:__ we currently recommend installing hivemind from github (i.e. not pip) as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in v0.9.0 release.
 
 * __Dependencies:__ Hivemind requires python 3.7+ (3.8 is recommended), it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
-* __OS support:__ Linux and Mac OS should [just work](https://github.com/learning-at-home/hivemind/issues).
-We do not officially support Windows, but you are welcome to try and contribute your windows build :)
+* __OS support:__ Linux and macOS should [just work](https://github.com/learning-at-home/hivemind/issues).
+We do not officially support Windows, but you are welcome to contribute your Windows build :)
 
 
 #### Host a server
@@ -158,7 +158,7 @@ out.sum().backward()  # backward pass
 ```
 
 When called, expert1 will submit a request to the corresponding server (which you created above) and return
- the outputs tensor(s) or raise an exception. During backward, pytorch will submit the backward requests
+ the output tensor(s) or raise an exception. During backward, pytorch will submit the backward requests
  for the experts as they appear in the computation graph.
  
 By default, the experts will automatically update their parameters with one step of SGD after each backward pass.

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.16'
+__version__ = '0.8.17'

+ 17 - 11
hivemind/dht/node.py

@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
 
-from sortedcontainers import SortedList
+from sortedcontainers import SortedSet
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
@@ -65,7 +65,7 @@ class DHTNode:
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
-    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
+    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on
 
@@ -115,7 +115,7 @@ class DHTNode:
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
 
         self.reuse_get_requests = reuse_get_requests
-        self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: - _res.sufficient_expiration_time))
+        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: - _res.sufficient_expiration_time))
 
         # caching policy
         self.refresh_timeout = refresh_timeout
@@ -468,14 +468,17 @@ class DHTNode:
                 raise e
 
     def _reuse_finished_search_result(self, finished: _SearchState):
-        search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
-        expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
-        concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
-        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
-        while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
-            concurrent_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
-            concurrent_requests[-1].finish_search()
-            concurrent_requests.pop(-1)
+        pending_requests = self.pending_get_requests[finished.key_id]
+        if finished.found_something:
+            search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
+            expiration_time_threshold = max(finished.expiration_time, finished.sufficient_expiration_time)
+            # note: pending_requests is sorted in the order of descending sufficient_expiration_time
+            while pending_requests and expiration_time_threshold >= pending_requests[-1].sufficient_expiration_time:
+                pending_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
+                pending_requests[-1].finish_search()
+                pending_requests.pop()
+        else:
+            pending_requests.discard(finished)
 
     def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
@@ -613,3 +616,6 @@ class _SearchState:
     def __lt__(self, other: _SearchState):
         """ _SearchState instances will be sorted by their target expiration time """
         return self.sufficient_expiration_time < other.sufficient_expiration_time
+
+    def __hash__(self):
+        return hash(self.key_id)

+ 1 - 2
tests/test_dht_node.py

@@ -136,8 +136,7 @@ def test_empty_table():
         protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-        f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+    assert recv_value == value and recv_expiration == expiration
 
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None