
Dht documentation & cosmetic fixes (#47)

* DHT -> HivemindDHT

* Server -> HivemindServer

* Server -> HivemindServer

* minor docstring renames

* update scheme for rtfd.io

* add basic dht docs

* docstring fixes across dht and server

* fix url resolver

* miscellaneous sphinx compatibility fixes

* miscellaneous sphinx compatibility fixes

* rollback HivemindDHT -> DHT, HivemindServer -> Server
justheuristic 5 years ago
parent
commit
656cdf9eb6

Binary
docs/_static/dht.odp


Binary
docs/_static/dht.pdf


Binary
docs/_static/dht.png

+ 2 - 1
docs/conf.py

@@ -240,7 +240,8 @@ def linkcode_resolve(domain, info):
     if domain != 'py' or not info['module']:
         return None
     try:
-        filename = 'hivemind/%s#L%d-L%d' % find_source()
+        filename = '%s#L%d-L%d' % find_source()
     except Exception:
         filename = info['module'].replace('.', '/') + '.py'
+
     return "https://github.com/learning-at-home/hivemind/blob/%s/%s" % (branch, filename)

+ 1 - 0
docs/index.rst

@@ -20,6 +20,7 @@ API documentation:
 
   modules/client.rst
   modules/server.rst
+  modules/dht.rst
 
 Indices and tables
 ==================

+ 44 - 0
docs/modules/dht.rst

@@ -0,0 +1,44 @@
+``hivemind.dht``
+====================
+
+.. image:: ../_static/dht.png
+   :width: 800
+
+.. automodule:: hivemind.dht
+
+.. currentmodule:: hivemind.dht
+
+
+.. autoclass:: DHT
+   :members:
+   :exclude-members: make_key
+   :member-order: bysource
+
+.. autoclass:: DHTNode
+   :members:
+   :member-order: bysource
+
+.. currentmodule:: hivemind.dht.protocol
+
+.. autoclass:: KademliaProtocol
+   :members:
+   :member-order: bysource
+
+.. currentmodule:: hivemind.dht.routing
+
+.. autoclass:: RoutingTable
+   :members:
+   :member-order: bysource
+
+.. autoclass:: KBucket
+   :members:
+   :member-order: bysource
+
+.. autoclass:: DHTID
+   :members:
+   :exclude-members: HASH_FUNC
+   :member-order: bysource
+
+.. currentmodule:: hivemind.dht.search
+
+.. autofunction:: traverse_dht

+ 1 - 1
docs/modules/server.rst

@@ -21,5 +21,5 @@
     :member-order: bysource
 
 .. autoclass:: TaskPool
-    :members: submit_task, form_batch, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty
+    :members: submit_task, iterate_minibatches, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty
     :member-order: bysource

+ 2 - 2
hivemind/client/expert.py

@@ -16,8 +16,8 @@ class RemoteExpert(nn.Module):
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to an error in the runtime.
 
     :param uid: unique expert identifier
-    :param host: hostname where Server operates
-    :param port: port to which Server listens
+    :param host: hostname where server operates
+    :param port: port to which server listens
     """
 
     def __init__(self, uid, host='127.0.0.1', port=8080):
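A hedged usage sketch matching the constructor shown above; the import path follows this file's location, while the forward call and input shape are assumptions about this version:

```python
import torch
from hivemind.client.expert import RemoteExpert

# construct per the signature above; calling it is assumed to behave like a
# local nn.Module whose input shapes must match the remote expert exactly
expert = RemoteExpert(uid='expert.0', host='127.0.0.1', port=8080)
output = expert(torch.randn(1, 1024))  # hypothetical input shape
```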

+ 7 - 1
hivemind/client/moe.py

@@ -147,7 +147,13 @@ class RemoteMixtureOfExperts(nn.Module):
 
     def compute_expert_scores(
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
-        """ TODO(jheuristic) docstring here """
+        """
+        Compute scores for each expert by adding up grid scores, autograd-friendly
+        :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
+        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
+        :returns: a tensor of scores, float32[batch_size, k]
+        :note: if some rows in batch have fewer than the max number of experts, their scores will be padded with -inf
+        """
         expert_counts = list(map(len, batch_experts))
         batch_size = len(batch_experts)
         max_num_experts = max(expert_counts)
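The new docstring describes a gather-and-sum over grid dimensions with -inf padding. Below is a toy, non-authoritative re-implementation of that idea, where per-expert grid coordinates are passed in directly as a stand-in for parsing real expert uids:

```python
import torch

def toy_expert_scores(grid_scores, batch_coords, pad=-float('inf')):
    """
    :param grid_scores: list of [batch_size, grid_size_i] tensors, one per grid dimension
    :param batch_coords: batch_coords[b] lists up to k coordinate tuples (one index
        per grid dimension) for the experts selected in row b
    """
    batch_size, k = len(batch_coords), max(map(len, batch_coords))
    scores = torch.full((batch_size, k), pad)  # rows with fewer experts keep -inf
    for b, coords in enumerate(batch_coords):
        for j, coord in enumerate(coords):
            # an expert's score is the sum of its scores along every grid dimension
            scores[b, j] = sum(grid_scores[i][b, c] for i, c in enumerate(coord))
    return scores

grid_scores = [torch.randn(2, 4), torch.randn(2, 3)]  # batch of 2, a 4 x 3 grid
print(toy_expert_scores(grid_scores, [[(0, 1), (3, 2)], [(1, 0)]]))
```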

+ 5 - 1
hivemind/dht/__init__.py

@@ -22,8 +22,9 @@ from ..utils import SharedFuture, find_open_port, Hostname, Port, run_in_backgro
 class DHT(mp.Process):
     """
     A high-level interface to hivemind DHT. Runs a dht node in a background process.
+
     :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
-    :param port: a port where DHT will listen to incoming connections. Defaults to hivemind.utils.find_open_port
+    :param port: a port where DHT node will listen to incoming connections. Defaults to hivemind.utils.find_open_port
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param node_params: any other params will be forwarded to DHTNode upon creation
@@ -45,6 +46,7 @@ class DHT(mp.Process):
             self.run_in_background(await_ready=True)
 
     def run(self) -> None:
+        """ Serve DHT forever. This function will not return until DHT node is shut down """
         if asyncio.get_event_loop().is_running():
             asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         loop = asyncio.new_event_loop()
@@ -102,6 +104,7 @@ class DHT(mp.Process):
     def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
         """
         Make experts available to DHT; update timestamps if already available
+
         :param uids: a list of expert ids to update
         :param addr: hostname that can be used to call this expert
         :param port: port that can be used to call this expert
@@ -139,6 +142,7 @@ class DHT(mp.Process):
     def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
         """
         Find k prefixes with active experts; may return fewer if there aren't enough; used for DMoE beam search
+
         :param prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
         :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
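A usage sketch built only from the parameters and docstrings visible in this diff; anything beyond them (e.g. teardown) relies on DHT subclassing mp.Process:

```python
from hivemind.dht import DHT

# start a DHT node in a background process, bootstrapping from one known peer
dht = DHT(initial_peers=[('127.0.0.1', 1337)], start=True)

# announce experts served at this address; expect one success flag per uid
flags = dht.declare_experts(['expert.0.1', 'expert.0.2'], addr='127.0.0.1', port=8080)

# DMoE beam search helper: at most k=2 uid prefixes that have active experts
active_prefixes = dht.first_k_active(['expert.0', 'expert.1', 'expert.2'], k=2)

dht.terminate()  # standard mp.Process teardown; DHT-specific shutdown is not shown in this diff
```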

+ 10 - 6
hivemind/dht/node.py

@@ -37,10 +37,12 @@ class DHTNode:
     Informally, dht nodes always prefer values with higher expiration_time and may delete any value past its expiration.
 
     Formally, DHTNode follows this contract:
-      - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
-       unless it already stores that key with greater or equal expiration_time - if so, node must keep the previous key
-      - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet
-       if expiration time is greater than current get_dht_time(), DHTNode *may* return None
+
+    - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
+      unless it already stores that key with greater or equal expiration_time - if so, node must keep the previous value
+    - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet;
+      if expiration time is smaller than current get_dht_time(), DHTNode *may* return None
+
     """
 
     def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
@@ -89,8 +91,9 @@ class DHTNode:
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
         """
         Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
-        :note: this is a thin wrapper over dht.search.beam_search, look there for more details
+
         :returns: an ordered dictionary of [peer DHTID -> network Endpoint], ordered from nearest to farthest neighbor
+        :note: this is a thin wrapper over dht.search.traverse_dht, look there for more details
         """
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
@@ -116,7 +119,8 @@ class DHTNode:
         """
         Find beam_size best nodes to store (key, value) and store it there at least until expiration time.
         Also cache (key, value, expiration_time) at all nodes you met along the way (see Section 2.1 end)
-        :return: True if store succeeds, False if it fails (due to no response or newer value)
+
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         key_id = DHTID.generate(key)
         nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)
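The reformatted contract above doubles as a tiny reference model. Here is a toy, single-process store that honors both clauses; it is purely illustrative (not hivemind code), with time.time() standing in for get_dht_time():

```python
import time

class ToyStore:
    def __init__(self):
        self.data = {}  # key -> (value, expiration_time)

    def store(self, key, value, expiration_time):
        old = self.data.get(key)
        if old is not None and old[1] >= expiration_time:
            return False  # keep the previous value with greater/equal expiration
        self.data[key] = (value, expiration_time)
        return True

    def get(self, key):
        entry = self.data.get(key)
        if entry is None or entry[1] < time.time():
            return None  # never stored, or expiration time has passed
        return entry

store = ToyStore()
assert store.store('k', 'new', time.time() + 60)
assert not store.store('k', 'stale', time.time() + 30)  # older expiration is rejected
assert store.get('k')[0] == 'new'
```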

+ 8 - 2
hivemind/dht/protocol.py

@@ -55,6 +55,7 @@ class KademliaProtocol(RPCProtocol):
                          expiration_time: DHTExpiration, in_cache: bool = False) -> Optional[bool]:
         """
         Ask a recipient to store (key, value) pair until expiration time or update their older value
+
         :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
         """
         responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
@@ -69,6 +70,7 @@ class KademliaProtocol(RPCProtocol):
                       query_id_bytes: BinaryDHTID) -> Tuple[List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
         """
         Someone wants to find :key_node: in the DHT. Give him k nearest neighbors from our routing table
+
         :returns: a list of pairs (node_id, address) of :bucket_size: nearest to key_node according to XOR distance,
          also returns our own node id for routing table maintenance
         """
@@ -81,6 +83,7 @@ class KademliaProtocol(RPCProtocol):
         """
         Ask a recipient to give you nearest neighbors to key_node. If recipient knows key_node directly,
          it will be returned as the first of the neighbors; if the recipient does not respond, return an empty dict.
+
         :returns: a dictionary[node id => address] as per Section 2.3 of the paper
         """
         responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
@@ -97,8 +100,9 @@ class KademliaProtocol(RPCProtocol):
         """
         Someone wants to find value corresponding to key. If we have the value, return the value and its expiration time
          Either way, return :bucket_size: nearest neighbors to that node.
-        :note: this is a deviation from Section 2.3 of the paper, original kademlia returner EITHER value OR neighbors
+
         :returns: (value or None if we have no value, nearest neighbors, our own dht id)
+        :note: this is a deviation from Section 2.3 of the paper, original Kademlia returned EITHER value OR neighbors
         """
         maybe_value, maybe_expiration = self.storage.get(DHTID.from_bytes(key_bytes))
         cached_value, cached_expiration = self.cache.get(DHTID.from_bytes(key_bytes))
@@ -111,11 +115,12 @@ class KademliaProtocol(RPCProtocol):
             Tuple[Optional[DHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]:
         """
         Ask a recipient to give you the value, if it has one, or nearest neighbors to your key.
+
         :returns: (optional value, optional expiration time, and neighbors)
          value: whatever was the latest value stored by the recipient with that key (see DHTNode contract)
          expiration time: expiration time of the returned value, None if no value was found
          neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
-        Note: if no response, returns None, None, {}
+        :note: if no response, returns None, None, {}
         """
         responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
         if responded:
@@ -128,6 +133,7 @@ class KademliaProtocol(RPCProtocol):
     async def update_routing_table(self, node_id: Optional[DHTID], addr: Endpoint, responded=True):
         """
         This method is called on every incoming AND outgoing request to update the routing table
+
         :param addr: sender endpoint for incoming requests, recipient endpoint for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param responded: for outgoing requests, this indicates whether the recipient responded or not.
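A sketch of the caller-side pattern implied by the call_find_value docstring above. `protocol`, `recipient`, and `key` are placeholders and the exact call shape is an assumption, so this is a shape sketch rather than a runnable setup:

```python
async def lookup_step(protocol, recipient, key):
    # per the docstring: (value, expiration, neighbors), or (None, None, {}) on no response
    value, expiration, neighbors = await protocol.call_find_value(recipient, key)
    if value is not None:
        return value, expiration  # recipient had a non-expired value
    if not neighbors:
        return None, None         # peer did not respond
    # otherwise, continue the search through the returned nearest neighbors
    return min(neighbors)  # placeholder: pick the next node id to query
```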

+ 9 - 2
hivemind/dht/routing.py

@@ -16,6 +16,7 @@ from ..utils import Endpoint, PickleSerializer
 class RoutingTable:
     """
     A data structure that contains DHT peers bucketed according to their distance to node_id
+
     :param node_id: node id used to measure distance
     :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
     :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2.
@@ -38,6 +39,7 @@ class RoutingTable:
     def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
         """
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
+
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :note: KademliaProtocol calls this method for every incoming and outgoing request if there was a response.
           If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
@@ -81,6 +83,7 @@ class RoutingTable:
             self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
         """
         Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
+
         :param query_id: find neighbors of this node
         :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
         :param exclude: if True, results will not contain query_node_id even if it is in table
@@ -147,6 +150,7 @@ class KBucket:
         """
         Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
         If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
+
         :param node_id: dht node identifier that should be added or moved to the front of bucket
         :param addr: a pair of (hostname, port) associated with that node id
         :note: this function has a side-effect of resetting KBucket.last_updated time
@@ -225,6 +229,7 @@ class DHTID(int):
     def generate(cls, source: Optional[Any] = None, nbits: int = 255):
         """
         Generates random uid based on SHA1
+
         :param source: if provided, converts this value to bytes and uses it as input for hashing function;
             by default, generates a random dhtid from :nbits: random bits
         """
@@ -249,11 +254,13 @@ class DHTID(int):
         return len(os.path.commonprefix(ids_bits))
 
     def to_bytes(self, length=HASH_NBYTES, byteorder='big', *, signed=False) -> bytes:
+        """ A standard way to serialize DHTID into bytes """
         return super().to_bytes(length, byteorder, signed=signed)
 
     @classmethod
-    def from_bytes(self, bytes, byteorder='big', *, signed=False) -> DHTID:
-        return DHTID(super().from_bytes(bytes, byteorder=byteorder, signed=signed))
+    def from_bytes(cls, raw: bytes, byteorder='big', *, signed=False) -> DHTID:
+        """ reverse of to_bytes """
+        return DHTID(super().from_bytes(raw, byteorder=byteorder, signed=signed))
 
     def __repr__(self):
         return f"{self.__class__.__name__}({hex(self)})"

+ 1 - 1
hivemind/dht/search.py

@@ -14,7 +14,7 @@ async def traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], k_near
     Approximate time complexity: O(T * log T) where T = (path_to_true_nearest + beam_size) * mean_num_neighbors
 
     :param query_id: search query, find k_nearest neighbors of this DHTID
-    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, *maybe_some_peers]
+    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
     :param k_nearest: find up to this many nearest neighbors. If there are fewer nodes in the DHT, return all nodes
     :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
         Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.
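A toy synchronous beam search that mimics the traversal contract described above. Hivemind's traverse_dht is asynchronous and talks to real peers; this sketch only captures the idea, with integer ids, a static neighbor map, and XOR as the distance:

```python
import heapq

def toy_traverse(query_id, initial_nodes, get_neighbors, k_nearest, beam_size):
    heap = [(node ^ query_id, node) for node in initial_nodes]  # (distance, node)
    heapq.heapify(heap)
    visited, expanded = set(initial_nodes), []
    while heap and len(expanded) < beam_size:  # don't give up until the beam is exhausted
        dist, node = heapq.heappop(heap)
        expanded.append((dist, node))
        for neighbor in get_neighbors(node):
            if neighbor not in visited:
                visited.add(neighbor)
                heapq.heappush(heap, (neighbor ^ query_id, neighbor))
    return [node for _, node in sorted(expanded)[:k_nearest]]

# a ring of 16 nodes where each node only knows its two ring neighbors
graph = {i: [(i + 1) % 16, (i - 1) % 16] for i in range(16)}
print(toy_traverse(query_id=5, initial_nodes=[0], get_neighbors=graph.__getitem__,
                   k_nearest=3, beam_size=8))  # finds node 5 itself first
```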

+ 1 - 0
hivemind/runtime/task_pool.py

@@ -97,6 +97,7 @@ class TaskPool(TaskPoolBase):
         return future2
 
     def iterate_minibatches(self, *args, **kwargs):
+        """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
         batch = []
         total_size = 0
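The new docstring summarizes a simple packing loop. A toy stand-alone version of that grouping logic (not the TaskPool code, just the idea it documents):

```python
def toy_iterate_minibatches(tasks, task_size, max_batch_size):
    """ group tasks greedily so each batch's total size stays within max_batch_size """
    batch, total_size = [], 0
    for task in tasks:
        size = task_size(task)
        if batch and total_size + size > max_batch_size:
            yield batch
            batch, total_size = [], 0
        batch.append(task)
        total_size += size
    if batch:
        yield batch

# tasks of sizes 3, 4, 2, 5 packed with max_batch_size=6 -> [3], [4, 2], [5]
print(list(toy_iterate_minibatches([3, 4, 2, 5], task_size=lambda t: t,
                                   max_batch_size=6)))
```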