Browse Source

Dht documentation & cosmetic fixes (#47)

* DHT -> HivemindDHT

* Server -> HivemindServer

* Server -> HivemindServer

* minor docstring renames

* update scheme for rtfd.io

* add basic dht docs

* docstring fixes across dht and server

* fix url resolver

* miscellaneous sphinx compatibility fixes

* miscellaneous sphinx compatibility fixes

* rollback HivemindDHT -> DHT, HivemindServer -> Server
justheuristic 5 years ago
parent
commit
656cdf9eb6

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.pdf


BIN
docs/_static/dht.png


+ 2 - 1
docs/conf.py

@@ -240,7 +240,8 @@ def linkcode_resolve(domain, info):
     if domain != 'py' or not info['module']:
     if domain != 'py' or not info['module']:
         return None
         return None
     try:
     try:
-        filename = 'hivemind/%s#L%d-L%d' % find_source()
+        filename = '%s#L%d-L%d' % find_source()
     except Exception:
     except Exception:
         filename = info['module'].replace('.', '/') + '.py'
         filename = info['module'].replace('.', '/') + '.py'
+
     return "https://github.com/learning-at-home/hivemind/blob/%s/%s" % (branch, filename)
     return "https://github.com/learning-at-home/hivemind/blob/%s/%s" % (branch, filename)

+ 1 - 0
docs/index.rst

@@ -20,6 +20,7 @@ API documentation:
 
 
   modules/client.rst
   modules/client.rst
   modules/server.rst
   modules/server.rst
+  modules/dht.rst
 
 
 Indices and tables
 Indices and tables
 ==================
 ==================

+ 44 - 0
docs/modules/dht.rst

@@ -0,0 +1,44 @@
+``hivemind.dht``
+====================
+
+.. image:: ../_static/dht.png
+   :width: 800
+
+.. automodule:: hivemind.dht
+
+.. currentmodule:: hivemind.dht
+
+
+.. autoclass:: DHT
+   :members:
+   :exclude-members: make_key
+   :member-order: bysource
+
+.. autoclass:: DHTNode
+   :members:
+   :member-order: bysource
+
+.. currentmodule:: hivemind.dht.protocol
+
+.. autoclass:: KademliaProtocol
+   :members:
+   :member-order: bysource
+
+.. currentmodule:: hivemind.dht.routing
+
+.. autoclass:: RoutingTable
+   :members:
+   :member-order: bysource
+
+.. autoclass:: KBucket
+   :members:
+   :member-order: bysource
+
+.. autoclass:: DHTID
+   :members:
+   :exclude-members: HASH_FUNC
+   :member-order: bysource
+
+.. currentmodule:: hivemind.dht.search
+
+.. autofunction:: traverse_dht

+ 1 - 1
docs/modules/server.rst

@@ -21,5 +21,5 @@
     :member-order: bysource
     :member-order: bysource
 
 
 .. autoclass:: TaskPool
 .. autoclass:: TaskPool
-    :members: submit_task, form_batch, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty
+    :members: submit_task, iterate_minibatches, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty
     :member-order: bysource
     :member-order: bysource

+ 2 - 2
hivemind/client/expert.py

@@ -16,8 +16,8 @@ class RemoteExpert(nn.Module):
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
 
 
     :param uid: unique expert identifier
     :param uid: unique expert identifier
-    :param host: hostname where Server operates
-    :param port: port to which Server listens
+    :param host: hostname where server operates
+    :param port: port to which server listens
     """
     """
 
 
     def __init__(self, uid, host='127.0.0.1', port=8080):
     def __init__(self, uid, host='127.0.0.1', port=8080):

+ 7 - 1
hivemind/client/moe.py

@@ -147,7 +147,13 @@ class RemoteMixtureOfExperts(nn.Module):
 
 
     def compute_expert_scores(
     def compute_expert_scores(
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
-        """ TODO(jheuristic) docstring here """
+        """
+        Compute scores for each expert by adding up grid scores, autograd-friendly
+        :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
+        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
+        :returns: a tensor of scores, float32[batch_size, k]
+        :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
+        """
         expert_counts = list(map(len, batch_experts))
         expert_counts = list(map(len, batch_experts))
         batch_size = len(batch_experts)
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
         max_num_experts = max(expert_counts)

+ 5 - 1
hivemind/dht/__init__.py

@@ -22,8 +22,9 @@ from ..utils import SharedFuture, find_open_port, Hostname, Port, run_in_backgro
 class DHT(mp.Process):
 class DHT(mp.Process):
     """
     """
     A high-level interface to hivemind DHT. Runs a dht node in a background process.
     A high-level interface to hivemind DHT. Runs a dht node in a background process.
+
     :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
     :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
-    :param port: a port where DHT will listen to incoming connections. Defaults to hivemind.utils.find_open_port
+    :param port: a port where DHT node will listen to incoming connections. Defaults to hivemind.utils.find_open_port
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param node_params: any other params will be forwarded to DHTNode upon creation
     :param node_params: any other params will be forwarded to DHTNode upon creation
@@ -45,6 +46,7 @@ class DHT(mp.Process):
             self.run_in_background(await_ready=True)
             self.run_in_background(await_ready=True)
 
 
     def run(self) -> None:
     def run(self) -> None:
+        """ Serve DHT forever. This function will not return until DHT node is shut down """
         if asyncio.get_event_loop().is_running():
         if asyncio.get_event_loop().is_running():
             asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
             asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         loop = asyncio.new_event_loop()
         loop = asyncio.new_event_loop()
@@ -102,6 +104,7 @@ class DHT(mp.Process):
     def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
     def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
         """
         """
         Make experts available to DHT; update timestamps if already available
         Make experts available to DHT; update timestamps if already available
+
         :param uids: a list of expert ids to update
         :param uids: a list of expert ids to update
         :param addr: hostname that can be used to call this expert
         :param addr: hostname that can be used to call this expert
         :param port: port that can be used to call this expert
         :param port: port that can be used to call this expert
@@ -139,6 +142,7 @@ class DHT(mp.Process):
     def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
     def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
         """
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
+
         :param prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
         :param k: return at most *this many* active prefixes
         :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
         :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k

+ 10 - 6
hivemind/dht/node.py

@@ -37,10 +37,12 @@ class DHTNode:
     Informally, dht nodes always prefer values with higher expiration_time and may delete any value past its expiration.
     Informally, dht nodes always prefer values with higher expiration_time and may delete any value past its expiration.
 
 
     Formally, DHTNode follows this contract:
     Formally, DHTNode follows this contract:
-      - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
-       unless it already stores that key with greater or equal expiration_time - if so, node must keep the previous key
-      - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet
-       if expiration time is greater than current get_dht_time(), DHTNode *may* return None
+
+    - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
+      unless it already stores that key with greater or equal expiration_time - if so, node must keep the previous key
+    - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet
+      if expiration time is greater than current get_dht_time(), DHTNode *may* return None
+
     """
     """
 
 
     def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
     def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
@@ -89,8 +91,9 @@ class DHTNode:
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
         """
         """
         Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
         Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
-        :note: this is a thin wrapper over dht.search.beam_search, look there for more details
+
         :returns: an ordered dictionary of [peer DHTID -> network Endpoint], ordered from nearest to farthest neighbor
         :returns: an ordered dictionary of [peer DHTID -> network Endpoint], ordered from nearest to farthest neighbor
+        :note: this is a thin wrapper over dht.search.traverse_dht, look there for more details
         """
         """
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
@@ -116,7 +119,8 @@ class DHTNode:
         """
         """
         Find beam_size best nodes to store (key, value) and store it there at least until expiration time.
         Find beam_size best nodes to store (key, value) and store it there at least until expiration time.
         Also cache (key, value, expiration_time) at all nodes you met along the way (see Section 2.1 end)
         Also cache (key, value, expiration_time) at all nodes you met along the way (see Section 2.1 end)
-        :return: True if store succeeds, False if it fails (due to no response or newer value)
+
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         """
         key_id = DHTID.generate(key)
         key_id = DHTID.generate(key)
         nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)
         nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)

+ 8 - 2
hivemind/dht/protocol.py

@@ -55,6 +55,7 @@ class KademliaProtocol(RPCProtocol):
                          expiration_time: DHTExpiration, in_cache: bool = False) -> Optional[bool]:
                          expiration_time: DHTExpiration, in_cache: bool = False) -> Optional[bool]:
         """
         """
         Ask a recipient to store (key, value) pair until expiration time or update their older value
         Ask a recipient to store (key, value) pair until expiration time or update their older value
+
         :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
         :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
         """
         """
         responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
         responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
@@ -69,6 +70,7 @@ class KademliaProtocol(RPCProtocol):
                       query_id_bytes: BinaryDHTID) -> Tuple[List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
                       query_id_bytes: BinaryDHTID) -> Tuple[List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
         """
         """
         Someone wants to find :key_node: in the DHT. Give him k nearest neighbors from our routing table
         Someone wants to find :key_node: in the DHT. Give him k nearest neighbors from our routing table
+
         :returns: a list of pairs (node_id, address) of :bucket_size: nearest to key_node according to XOR distance,
         :returns: a list of pairs (node_id, address) of :bucket_size: nearest to key_node according to XOR distance,
          also returns our own node id for routing table maintenance
          also returns our own node id for routing table maintenance
         """
         """
@@ -81,6 +83,7 @@ class KademliaProtocol(RPCProtocol):
         """
         """
         Ask a recipient to give you nearest neighbors to key_node. If recipient knows key_node directly,
         Ask a recipient to give you nearest neighbors to key_node. If recipient knows key_node directly,
          it will be returned as first of the neighbors; if recipient does not respond, return empty dict.
          it will be returned as first of the neighbors; if recipient does not respond, return empty dict.
+
         :returns: a dictionary[node id => address] as per Section 2.3 of the paper
         :returns: a dictionary[node id => address] as per Section 2.3 of the paper
         """
         """
         responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
         responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
@@ -97,8 +100,9 @@ class KademliaProtocol(RPCProtocol):
         """
         """
         Someone wants to find value corresponding to key. If we have the value, return the value and its expiration time
         Someone wants to find value corresponding to key. If we have the value, return the value and its expiration time
          Either way, return :bucket_size: nearest neighbors to that node.
          Either way, return :bucket_size: nearest neighbors to that node.
-        :note: this is a deviation from Section 2.3 of the paper, original kademlia returner EITHER value OR neighbors
+
         :returns: (value or None if we have no value, nearest neighbors, our own dht id)
         :returns: (value or None if we have no value, nearest neighbors, our own dht id)
+        :note: this is a deviation from Section 2.3 of the paper, original kademlia returned EITHER value OR neighbors
         """
         """
         maybe_value, maybe_expiration = self.storage.get(DHTID.from_bytes(key_bytes))
         maybe_value, maybe_expiration = self.storage.get(DHTID.from_bytes(key_bytes))
         cached_value, cached_expiration = self.cache.get(DHTID.from_bytes(key_bytes))
         cached_value, cached_expiration = self.cache.get(DHTID.from_bytes(key_bytes))
@@ -111,11 +115,12 @@ class KademliaProtocol(RPCProtocol):
             Tuple[Optional[DHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]:
             Tuple[Optional[DHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]:
         """
         """
         Ask a recipient to give you the value, if it has one, or nearest neighbors to your key.
         Ask a recipient to give you the value, if it has one, or nearest neighbors to your key.
+
         :returns: (optional value, optional expiration time, and neighbors)
         :returns: (optional value, optional expiration time, and neighbors)
          value: whatever was the latest value stored by the recipient with that key (see DHTNode contract)
          value: whatever was the latest value stored by the recipient with that key (see DHTNode contract)
          expiration time: expiration time of the returned value, None if no value was found
          expiration time: expiration time of the returned value, None if no value was found
          neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
          neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
-        Note: if no response, returns None, None, {}
+        :note: if no response, returns None, None, {}
         """
         """
         responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
         responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
         if responded:
         if responded:
@@ -128,6 +133,7 @@ class KademliaProtocol(RPCProtocol):
     async def update_routing_table(self, node_id: Optional[DHTID], addr: Endpoint, responded=True):
     async def update_routing_table(self, node_id: Optional[DHTID], addr: Endpoint, responded=True):
         """
         """
         This method is called on every incoming AND outgoing request to update the routing table
         This method is called on every incoming AND outgoing request to update the routing table
+
         :param addr: sender endpoint for incoming requests, recipient endpoint for outgoing requests
         :param addr: sender endpoint for incoming requests, recipient endpoint for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param responded: for outgoing requests, this indicates whether recipient responded or not.
         :param responded: for outgoing requests, this indicates whether recipient responded or not.

+ 9 - 2
hivemind/dht/routing.py

@@ -16,6 +16,7 @@ from ..utils import Endpoint, PickleSerializer
 class RoutingTable:
 class RoutingTable:
     """
     """
     A data structure that contains DHT peers bucketed according to their distance to node_id
     A data structure that contains DHT peers bucketed according to their distance to node_id
+
     :param node_id: node id used to measure distance
     :param node_id: node id used to measure distance
     :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
     :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
     :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2.
     :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2.
@@ -38,6 +39,7 @@ class RoutingTable:
     def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
     def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
         """
         """
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
+
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :note: KademliaProtocol calls this method for every incoming and outgoing request if there was a response.
         :note: KademliaProtocol calls this method for every incoming and outgoing request if there was a response.
           If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
           If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
@@ -81,6 +83,7 @@ class RoutingTable:
             self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
             self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
         """
         """
         Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
         Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
+
         :param query_id: find neighbors of this node
         :param query_id: find neighbors of this node
         :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
         :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
         :param exclude: if True, results will not contain query_node_id even if it is in table
         :param exclude: if True, results will not contain query_node_id even if it is in table
@@ -147,6 +150,7 @@ class KBucket:
         """
         """
         Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
         Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
         If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
         If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
+
         :param node_id: dht node identifier that should be added or moved to the front of bucket
         :param node_id: dht node identifier that should be added or moved to the front of bucket
         :param addr: a pair of (hostname, port) associated with that node id
         :param addr: a pair of (hostname, port) associated with that node id
         :note: this function has a side-effect of resetting KBucket.last_updated time
         :note: this function has a side-effect of resetting KBucket.last_updated time
@@ -225,6 +229,7 @@ class DHTID(int):
     def generate(cls, source: Optional[Any] = None, nbits: int = 255):
     def generate(cls, source: Optional[Any] = None, nbits: int = 255):
         """
         """
         Generates random uid based on SHA1
         Generates random uid based on SHA1
+
         :param source: if provided, converts this value to bytes and uses it as input for hashing function;
         :param source: if provided, converts this value to bytes and uses it as input for hashing function;
             by default, generates a random dhtid from :nbits: random bits
             by default, generates a random dhtid from :nbits: random bits
         """
         """
@@ -249,11 +254,13 @@ class DHTID(int):
         return len(os.path.commonprefix(ids_bits))
         return len(os.path.commonprefix(ids_bits))
 
 
     def to_bytes(self, length=HASH_NBYTES, byteorder='big', *, signed=False) -> bytes:
     def to_bytes(self, length=HASH_NBYTES, byteorder='big', *, signed=False) -> bytes:
+        """ A standard way to serialize DHTID into bytes """
         return super().to_bytes(length, byteorder, signed=signed)
         return super().to_bytes(length, byteorder, signed=signed)
 
 
     @classmethod
     @classmethod
-    def from_bytes(self, bytes, byteorder='big', *, signed=False) -> DHTID:
-        return DHTID(super().from_bytes(bytes, byteorder=byteorder, signed=signed))
+    def from_bytes(cls, raw: bytes, byteorder='big', *, signed=False) -> DHTID:
+        """ reverse of to_bytes """
+        return DHTID(super().from_bytes(raw, byteorder=byteorder, signed=signed))
 
 
     def __repr__(self):
     def __repr__(self):
         return f"{self.__class__.__name__}({hex(self)})"
         return f"{self.__class__.__name__}({hex(self)})"

+ 1 - 1
hivemind/dht/search.py

@@ -14,7 +14,7 @@ async def traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], k_near
     Approximate time complexity: O(T * log T) where T = (path_to_true_nearest + beam_size) * mean_num_neighbors
     Approximate time complexity: O(T * log T) where T = (path_to_true_nearest + beam_size) * mean_num_neighbors
 
 
     :param query_id: search query, find k_nearest neighbors of this DHTID
     :param query_id: search query, find k_nearest neighbors of this DHTID
-    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, *maybe_some_peers]
+    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
     :param k_nearest: find up to this many nearest neighbors. If there are less nodes in the DHT, return all nodes
     :param k_nearest: find up to this many nearest neighbors. If there are less nodes in the DHT, return all nodes
     :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
     :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
         Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.
         Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.

+ 1 - 0
hivemind/runtime/task_pool.py

@@ -97,6 +97,7 @@ class TaskPool(TaskPoolBase):
         return future2
         return future2
 
 
     def iterate_minibatches(self, *args, **kwargs):
     def iterate_minibatches(self, *args, **kwargs):
+        """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
         batch = []
         batch = []
         total_size = 0
         total_size = 0