فهرست منبع

Reuse only successful DHT search results (#130)

* Reuse only successful DHT search results

* Fix typos in docs

* Remove unnecessary assertion message
(pytest displays unequal values on failure anyway)

* SortedList -> SortedSet, delete by key
Max Ryabinin 4 سال پیش
والد
کامیت
0ed7b46bb6
5 فایل‌ تغییر یافته به همراه 26 افزوده شده و 21 حذف شده
  1. 2 2
      docs/user/acknowledgements.md
  2. 5 5
      docs/user/quickstart.md
  3. 1 1
      hivemind/__init__.py
  4. 17 11
      hivemind/dht/node.py
  5. 1 2
      tests/test_dht_node.py

+ 2 - 2
docs/user/acknowledgements.md

@@ -18,9 +18,9 @@ We kindly thank (in random order)
 
 
 We also want to reference several projects that have similar ideas in mind:
 We also want to reference several projects that have similar ideas in mind:
 
 
-* [BitTensor](https://github.com/opentensor/BitTensor) - a decentralized deep learning ecosystem with with incentive
+* [BitTensor](https://github.com/opentensor/BitTensor) - a decentralized deep learning ecosystem with incentive
  mechanism. Like hivemind, but peers are getting rewarded for their contribution to other peers.
  mechanism. Like hivemind, but peers are getting rewarded for their contribution to other peers.
   _(note: as of 26.08.2020 the project is in the early stages development)_.
   _(note: as of 26.08.2020 the project is in the early stages development)_.
-* [GShard](https://arxiv.org/abs/2006.16668) - a paper by Dmitry Lepikhin et al that demonstrate the effectiveness
+* [GShard](https://arxiv.org/abs/2006.16668) - a paper by Dmitry Lepikhin et al. that demonstrate the effectiveness
   of huge Mixture-of-Experts models on conventional hpc hardware. Those guys train models 4 times the size of GPT-3 on thousands of TPUv3.
   of huge Mixture-of-Experts models on conventional hpc hardware. Those guys train models 4 times the size of GPT-3 on thousands of TPUv3.
 * Also doing research in decentralized deep learning? Let us know!
 * Also doing research in decentralized deep learning? Let us know!

+ 5 - 5
docs/user/quickstart.md

@@ -11,16 +11,16 @@ You can also install the bleeding edge version from github:
 ```
 ```
 git clone https://github.com/learning-at-home/hivemind
 git clone https://github.com/learning-at-home/hivemind
 cd hivemind
 cd hivemind
-python setup.py install
+pip install .
 ```
 ```
 
 
-You can also install it in editable mode with `python setup.py develop`.
+You can also install it in the editable mode with `pip install -e .`.
 
 
 __Note:__ we currently recommend installing hivemind from github (i.e. not pip) as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in v0.9.0 release.
 __Note:__ we currently recommend installing hivemind from github (i.e. not pip) as it can run RemoteMixtureOfExperts faster by an order of magnitude. These changes will only reach PyPI in v0.9.0 release.
 
 
 * __Dependencies:__ Hivemind requires python 3.7+ (3.8 is recommended), it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
 * __Dependencies:__ Hivemind requires python 3.7+ (3.8 is recommended), it will install [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) automatically; 
-* __OS support:__ Linux and Mac OS should [just work](https://github.com/learning-at-home/hivemind/issues).
-We do not officially support Windows, but you are welcome to try and contribute your windows build :)
+* __OS support:__ Linux and macOS should [just work](https://github.com/learning-at-home/hivemind/issues).
+We do not officially support Windows, but you are welcome to contribute your windows build :)
 
 
 
 
 #### Host a server
 #### Host a server
@@ -158,7 +158,7 @@ out.sum().backward()  # backward pass
 ```
 ```
 
 
 When called, expert1 will submit a request to the corresponding server (which you created above) and return
 When called, expert1 will submit a request to the corresponding server (which you created above) and return
- the outputs tensor(s) or raise an exception. During backward, pytorch will submit the backward requests
+ the output tensor(s) or raise an exception. During backward, pytorch will submit the backward requests
  for the experts as they appear in the computation graph.
  for the experts as they appear in the computation graph.
  
  
 By default, the experts will automatically update their parameters with one step of SGD after each backward pass.
 By default, the experts will automatically update their parameters with one step of SGD after each backward pass.

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.utils import *
 
 
-__version__ = '0.8.16'
+__version__ = '0.8.17'

+ 17 - 11
hivemind/dht/node.py

@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
 from functools import partial
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
 
 
-from sortedcontainers import SortedList
+from sortedcontainers import SortedSet
 
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
@@ -65,7 +65,7 @@ class DHTNode:
     # fmt:off
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
-    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
+    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on
     # fmt:on
 
 
@@ -115,7 +115,7 @@ class DHTNode:
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
 
 
         self.reuse_get_requests = reuse_get_requests
         self.reuse_get_requests = reuse_get_requests
-        self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: - _res.sufficient_expiration_time))
+        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: - _res.sufficient_expiration_time))
 
 
         # caching policy
         # caching policy
         self.refresh_timeout = refresh_timeout
         self.refresh_timeout = refresh_timeout
@@ -468,14 +468,17 @@ class DHTNode:
                 raise e
                 raise e
 
 
     def _reuse_finished_search_result(self, finished: _SearchState):
     def _reuse_finished_search_result(self, finished: _SearchState):
-        search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
-        expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
-        concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
-        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
-        while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
-            concurrent_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
-            concurrent_requests[-1].finish_search()
-            concurrent_requests.pop(-1)
+        pending_requests = self.pending_get_requests[finished.key_id]
+        if finished.found_something:
+            search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
+            expiration_time_threshold = max(finished.expiration_time, finished.sufficient_expiration_time)
+            # note: pending_requests is sorted in the order of descending sufficient_expiration_time
+            while pending_requests and expiration_time_threshold >= pending_requests[-1].sufficient_expiration_time:
+                pending_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
+                pending_requests[-1].finish_search()
+                pending_requests.pop()
+        else:
+            pending_requests.discard(finished)
 
 
     def _trigger_cache_refresh(self, search: _SearchState):
     def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
@@ -613,3 +616,6 @@ class _SearchState:
     def __lt__(self, other: _SearchState):
     def __lt__(self, other: _SearchState):
         """ _SearchState instances will be sorted by their target expiration time """
         """ _SearchState instances will be sorted by their target expiration time """
         return self.sufficient_expiration_time < other.sufficient_expiration_time
         return self.sufficient_expiration_time < other.sufficient_expiration_time
+
+    def __hash__(self):
+        return hash(self.key_id)

+ 1 - 2
tests/test_dht_node.py

@@ -136,8 +136,7 @@ def test_empty_table():
         protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
         protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-        f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+    assert recv_value == value and recv_expiration == expiration
 
 
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None