Browse files

Merge pull request #17 from learning-at-home/api_reference

API reference (docs)
justheuristic 5 years ago
parent
commit
09d11f43fe

+ 4 - 0
docs/_static/fix_rtd.css

@@ -1,4 +1,8 @@
 /* work around https://github.com/snide/sphinx_rtd_theme/issues/149 */
 .rst-content table.field-list .field-body {
     padding-top: 8px;
+}
+/* unlimited page width */
+.wy-nav-content {
+    max-width: none;
 }

+ 3 - 2
docs/conf.py

@@ -66,8 +66,7 @@ templates_path = ['_templates']
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
-source_parsers = {'.md': CommonMarkParser}
-source_suffix = ['.rst', '.md']
+source_suffix = {'.rst': 'restructuredtext', '.md': 'markdown'}
 
 # The master toctree document.
 master_doc = 'index'
@@ -217,6 +216,8 @@ def setup(app):
         # 'enable_auto_doc_ref': True,
     }, True)
     app.add_transform(AutoStructify)
+    app.add_source_suffix('.md', 'markdown')
+    app.add_source_parser(CommonMarkParser)
 
 
 #  Resolve function for the linkcode extension.
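
Taken together, the conf.py changes replace the deprecated `source_parsers` dict with the Sphinx >= 1.8 registration API. A minimal sketch of the resulting markdown setup, assuming recommonmark is installed (anything not shown in this diff is illustrative):

```python
# docs/conf.py -- minimal sketch of the pattern adopted in this commit
from recommonmark.parser import CommonMarkParser
from recommonmark.transform import AutoStructify

# map each source suffix to a parser type instead of the old source_parsers dict
source_suffix = {'.rst': 'restructuredtext', '.md': 'markdown'}

def setup(app):
    app.add_transform(AutoStructify)          # let recommonmark post-process markdown
    app.add_source_suffix('.md', 'markdown')  # register the suffix -> filetype mapping
    app.add_source_parser(CommonMarkParser)   # register the markdown parser itself
```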

+ 2 - 2
docs/index.rst

@@ -1,7 +1,7 @@
 ``learning@home::tesseract``
 ====================================
 
-Tesseract lets you train huge neural networks on computers provided by volunteers. Powered by pytorch
+Tesseract lets you train huge neural networks on computers provided by volunteers. Powered by pytorch.
 
 .. image:: _static/bug.gif
 
@@ -19,7 +19,7 @@ API documentation:
   :maxdepth: 2
 
   modules/client.rst
-  modules/runtime.md
+  modules/server.rst
 
 Indices and tables
 ==================

+ 2 - 2
docs/modules/client.rst

@@ -1,5 +1,5 @@
-tesseract.client
-================
+``tesseract.client``
+====================
 
 .. automodule:: tesseract.client
 

+ 0 - 3
docs/modules/runtime.md

@@ -1,3 +0,0 @@
-# Runtime 
-
-TODO i explain runtime

+ 25 - 0
docs/modules/server.rst

@@ -0,0 +1,25 @@
+``tesseract.server & runtime``
+========================================
+
+.. automodule:: tesseract.server
+
+.. currentmodule:: tesseract.server
+
+.. autoclass:: TesseractServer
+   :members:
+   :member-order: bysource
+
+.. currentmodule:: tesseract.runtime
+
+.. autoclass:: TesseractRuntime
+    :members:
+    :member-order: bysource
+
+
+.. autoclass:: ExpertBackend
+    :members: forward, backward, apply_gradients, get_info, get_pools
+    :member-order: bysource
+
+.. autoclass:: TaskPool
+    :members: submit_task, form_batch, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty
+    :member-order: bysource

+ 6 - 1
docs/user/quickstart.md

@@ -1 +1,6 @@
-# Quick start
+# Quick start [nothing here yet]
+
+This will eventually become a tutorial on how to host a tesseract node or connect to an existing node.
+
+![img](https://media.giphy.com/media/3oz8xtBx06mcZWoNJm/giphy.gif)
+

+ 6 - 0
tesseract/network/__init__.py

@@ -1,6 +1,7 @@
 import asyncio
 import datetime
 import multiprocessing as mp
+import warnings
 from typing import Tuple, List, Optional
 
 from kademlia.network import Server
@@ -33,6 +34,11 @@ class TesseractNetwork(mp.Process):
             method, args, kwargs = self._pipe.recv()
             getattr(self, method)(*args, **kwargs)
 
+    def shutdown(self) -> None:
+        """ Shuts down the network process """
+        warnings.warn("TODO shutdown network gracefully")
+        self.terminate()
+
     def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
         """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]"""
         future, _future = SharedFuture.make_pair()
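
For illustration, a hypothetical caller of the DHT lookup touched above; the expert uids and the assumption that TesseractNetwork is default-constructible are not from this diff:

```python
# Hypothetical usage of TesseractNetwork.get_experts; uids and constructor
# arguments are illustrative assumptions, not part of this commit.
from tesseract.network import TesseractNetwork

network = TesseractNetwork()  # assumed default-constructible for this sketch
network.start()

uids = ['expert.0', 'expert.1']
experts = network.get_experts(uids)  # list of [RemoteExpert if found else None]
for uid, expert in zip(uids, experts):
    print(uid, '->', expert if expert is not None else 'not found')

network.shutdown()  # currently just terminates the process (graceful shutdown is a TODO)
```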

+ 22 - 7
tesseract/runtime/__init__.py

@@ -13,14 +13,29 @@ from .task_pool import TaskPool, TaskPoolBase
 
 
 class TesseractRuntime(threading.Thread):
+    """
+    A group of processes that process incoming requests for multiple experts on a shared device.
+    TesseractRuntime is usually created and managed by TesseractServer; humans need not apply.
+
+    For debugging, you can start the runtime manually with .start() or .run()
+
+    >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)}
+    >>> runtime = TesseractRuntime(expert_backends)
+    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
+    >>> runtime.ready.wait()  # wait for runtime to load all experts on device and create request pools
+    >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs)
+    >>> print("Returned:", future.result())
+    >>> runtime.shutdown()
+
+    :param expert_backends: a dict [expert uid -> ExpertBackend]
+    :param prefetch_batches: form up to this many batches in advance
+    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
+    :param device: if specified, moves all experts and data to this device via .to(device=device).
+      If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
+    """
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
                  device: torch.device = None):
-        """
-        A group of processes that process tasks for multiple experts on a shared device
-        :param expert_backends: a dict [expert uid -> ExpertBackend]
-        :param prefetch_batches: generate up to this many batches in advance
-        :param start: start runtime immediately (at the end of __init__)
-        """
         super().__init__()
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
@@ -52,7 +67,7 @@ class TesseractRuntime(threading.Thread):
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
 
     def shutdown(self):
-        """ Trigger runtime to terminate, process-save """
+        """ Gracefully terminate a running runtime. """
         self.ready.clear()
         self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
         for pool in self.pools:
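
The shutdown path above sends a sentinel through a dedicated pipe that a background thread watches alongside the task pools. A simplified, generic sketch of that trigger pattern (not the actual TesseractRuntime loop):

```python
# Generic illustration of a pipe-based shutdown trigger; the real runtime
# loop differs, this only shows the mechanism.
import multiprocessing as mp
from multiprocessing.connection import wait

SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
shutdown_recv, shutdown_send = mp.Pipe(duplex=False)

def runtime_loop(task_pipes):
    while True:
        # block until either a task pipe or the shutdown pipe has data
        for ready in wait(list(task_pipes) + [shutdown_recv]):
            if ready is shutdown_recv and ready.recv() == SHUTDOWN_TRIGGER:
                return  # graceful exit: stop pulling new batches
            # otherwise: read a batch from `ready` and run it through an expert

# elsewhere: shutdown_send.send(SHUTDOWN_TRIGGER) asks the loop to exit
```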

+ 59 - 25
tesseract/runtime/expert_backend.py

@@ -8,32 +8,32 @@ from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProt
 
 
 class ExpertBackend(nn.Module):
+    """
+    ExpertBackend is a wrapper around a torch module that allows it to run tasks asynchronously with TesseractRuntime.
+    By default, ExpertBackend handles three types of requests:
+
+     - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
+     - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
+     - get_info - return expert metadata. Not batched.
+
+    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+
+        - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of the same type
+        - All \*args, \*\*kwargs and outputs must be **tensors** whose 0-th dimension represents the batch size
+        - We recommend using experts that are ~invariant to the order in which they process batches
+
+    :param opt: torch optimizer to be applied on every backward call
+    :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
+    :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
+    :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
+    :param kwargs: extra parameters to be forwarded into TaskPool.__init__
+    """
+
     def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
                  args_schema: Tuple[BatchTensorProto, ...] = None,
                  kwargs_schema: Dict[str, BatchTensorProto] = None,
                  outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
                  **kwargs):
-        """
-        ExpertBackend implements how a given expert processes tasks.
-        By default, there are two tasks:
-         * forward receives inputs and produces outputs
-         * backward receives gradients w.r.t. outputs, computes gradients w.r.t. inputs and trains the expert
-
-        All incoming tasks are grouped by type (forward/backward) and sent into the corresponding pool,
-        where tasks are grouped into minibatches and prepared for processing on device;
-        The results are dispatched to task authors with SharedFuture.set_result.
-
-        :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
-            * Experts must always receive the same set of *args and **kwargs and produce output tensors of same type
-            * All *args, **kwargs and outputs must be *tensors* where 0-th dimension represents to batch size
-            * We recommend using experts that are ~invariant to the order in which they process batches
-
-        :param opt: torch optimizer to be applied on every backward call
-        :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
-        :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
-        :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
-        :param kwargs: extra parameters to be forwarded into TaskPool.__init__
-        """
         super().__init__()
         self.expert, self.opt, self.name = expert, opt, name
 
@@ -56,15 +56,43 @@ class ExpertBackend(nn.Module):
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
 
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        """
+        Apply forward pass to an aggregated batch of requests. Used by TesseractRuntime, do not call this manually;
+        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
+
+        Subclassing:
+           This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
+
+           It should return outputs that follow ``nested_flatten(self.outputs_schema)``;
+
+           .. todo:: document state: we recommend stateless experts; if you keep state, disable batchnorm running-stats tracking
+
+        """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
         with torch.no_grad():
             outputs = self.expert(*args, **kwargs)
 
-        # Note: TaskPool requires function to accept and return a **list** of values, we pack/unpack it on client side
+        # Note: TaskPool requires the function to accept and return a flat tuple of values; we pack/unpack it on the client side
         return tuple(nested_flatten(outputs))
 
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        """
+        Apply backward pass to an aggregated batch of requests. Used by TesseractRuntime, do not call this manually;
+        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.
+
+        Subclassing:
+           This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
+
+           It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;
+
+           TesseractRuntime doesn't guarantee that backward will be performed in the same order and for the same data
+           as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
+
+           .. todo:: document state, randomness, etc.
+
+           Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
+        """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
         with torch.enable_grad():
@@ -87,12 +115,18 @@ class ExpertBackend(nn.Module):
                      for x in nested_flatten((args, kwargs)))
 
     def apply_gradients(self) -> None:
+        """
+        Train the expert for a single step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        """
         self.opt.step()
         self.opt.zero_grad()
 
-    def get_pools(self) -> Sequence[TaskPool]:
-        return self.forward_pool, self.backward_pool
-
     def get_info(self) -> Dict[str, Any]:
+        """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """
         return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
                     keyword_names=tuple(self.kwargs_schema.keys()))
+
+    def get_pools(self) -> Sequence[TaskPool]:
+        """ return all pools that should be processed by ``TesseractRuntime`` """
+        return self.forward_pool, self.backward_pool
+
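
To make the constructor docstring concrete, a hypothetical backend around a linear layer; the BatchTensorProto constructor arguments are an assumption, while the keyword names come from this diff:

```python
# Hypothetical ExpertBackend construction; BatchTensorProto's arguments are
# assumed for illustration, everything else follows the signature above.
import torch
import torch.nn as nn
from tesseract.runtime import ExpertBackend
from tesseract.utils import BatchTensorProto

expert = nn.Linear(512, 512)
backend = ExpertBackend(
    name='ffn_expert',
    expert=expert,
    opt=torch.optim.Adam(expert.parameters()),
    args_schema=(BatchTensorProto(512),),  # one positional input of shape [batch, 512] (ctor assumed)
    outputs_schema=BatchTensorProto(512),  # single output of the same shape
    max_batch_size=2048,                   # forwarded into TaskPool.__init__ via **kwargs
)
# a runtime then serves backend.get_pools() == (forward_pool, backward_pool)
```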

+ 15 - 12
tesseract/runtime/task_pool.py

@@ -54,21 +54,23 @@ class TaskPoolBase(mp.Process):
 
 
 class TaskPool(TaskPoolBase):
+    """
+    Request aggregator that accepts processing requests, groups them into batches, waits for TesseractRuntime
+    to process these batches and dispatches results back to request sources. Operates as a background process.
+
+    :param process_func: function to be applied to every formed batch; called by TesseractRuntime
+        Note that process_func should accept only \*args Tensors and return a flat tuple of Tensors
+    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
+    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param timeout: wait for a subsequent task for at most this many seconds
+    :param pool_size: store at most this many unprocessed tasks in a queue
+    :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
+    :param uid: pool identifier used for shared array allocation
+    :param start: if True, start automatically at the end of __init__
+    """
 
     def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
                  timeout=None, pool_size=None, prefetch_batches=1, uid=None, start=False):
-        """
-        Naive implementation of task pool that forms batch from earliest submitted tasks
-        :param process_func: function to be applied to every formed batch; called by TesseractRuntime
-            Note: process_func should accept only *args Tensors and return a list of output Tensors
-        :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)
-        :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
-        :param timeout: wait for a subsequent task for at most this many seconds
-        :param pool_size: store at most this many unprocessed tasks in a queue
-        :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
-        :param uid: pool identifier used for shared array allocation
-        :param start: if True, start automatically at the end of __init__
-        """
 
         super().__init__(process_func)
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
@@ -88,6 +90,7 @@ class TaskPool(TaskPoolBase):
             self.start()
 
     def submit_task(self, *args: torch.Tensor) -> Future:
+        """ Add task to this pool's queue, return Future for its output """
         future1, future2 = SharedFuture.make_pair()
         self.tasks.put(Task(future1, args))
         self.undispatched_task_timestamps.put(time.time())
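
For reference, the request path the docstring describes, as a hypothetical standalone use; `process_func` and the tensor shapes are illustrative:

```python
# Hypothetical standalone TaskPool use; in practice TesseractRuntime, not the
# caller, drains the pool, so this only shows the submission side.
import torch
from tesseract.runtime.task_pool import TaskPool

def process_func(*batch: torch.Tensor):
    # applied by the runtime to an aggregated batch; must return a flat tuple
    return (batch[0] * 2,)

pool = TaskPool(process_func, max_batch_size=16, uid='demo_pool', start=True)
future = pool.submit_task(torch.randn(4, 8))  # one task with batch dimension 4
# future.result() resolves once a runtime loads the batch and sends outputs back
```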

+ 55 - 6
tesseract/server/__init__.py

@@ -3,7 +3,6 @@ import os
 import threading
 from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
 from typing import Dict
-from warnings import warn
 
 from .connection_handler import handle_connection
 from .network_handler import NetworkHandlerThread
@@ -12,6 +11,28 @@ from ..runtime import TesseractRuntime, ExpertBackend
 
 
 class TesseractServer(threading.Thread):
+    """
+    TesseractServer allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
+    After creation, a server should be started: see TesseractServer.run or TesseractServer.run_in_background.
+
+    A working server does 3 things:
+     - processes incoming forward/backward requests via TesseractRuntime (created by the server)
+     - publishes updates to expert status every :update_period: seconds
+     - follows orders from TesseractController - if it exists
+
+    :param network: TesseractNetwork or None. A server with network=None will NOT be visible from the DHT,
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, host=IPADDR, port=PORT).
+    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
+    :param addr: server's network address that determines how it can be accessed. Default is local connections only.
+    :param port: port to which server listens for requests such as expert forward or backward pass.
+    :param conn_handler_processes: maximum number of simultaneous requests. Please note that the default value of 1
+        is too small for normal functioning; we recommend 4 handlers per expert backend.
+    :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
+        if network is None, this parameter is ignored.
+    :param start: if True, the server will immediately start as a background thread and return control after the server
+        is ready (see .ready below)
+    """
+
     def __init__(self, network: TesseractNetwork, expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
                  port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
                  **kwargs):
@@ -22,9 +43,13 @@ class TesseractServer(threading.Thread):
         self.runtime = TesseractRuntime(self.experts, **kwargs)
 
         if start:
-            self.start()
+            self.run_in_background(await_ready=True)
 
     def run(self):
+        """
+        Starts TesseractServer in the current thread. Initializes network if necessary, starts connection handlers,
+        runs TesseractRuntime (self.runtime) to process incoming requests.
+        """
         if self.network:
             if not self.network.is_alive():
                 self.network.start()
@@ -44,8 +69,26 @@ class TesseractServer(threading.Thread):
         if self.network:
             network_thread.join()
 
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts TesseractServer in a background thread. If await_ready, this method will wait until the background server
+        is ready to process incoming requests or until :timeout: seconds elapse, whichever happens first.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"TesseractServer didn't notify .ready in {timeout} seconds")
+
     @property
-    def ready(self):
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
     def _create_connection_handlers(self, num_handlers):
@@ -60,11 +103,17 @@ class TesseractServer(threading.Thread):
         return processes
 
     def shutdown(self):
-        """ Gracefully terminate a tesseract server, process-safe """
-        self.runtime.shutdown()
+        """
+        Gracefully terminate a tesseract server, process-safe.
+        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
+        If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
         for process in self.conn_handlers:
             process.terminate()
-        warn("TODO shutdown network")
+        self.runtime.shutdown()
+        if self.network is not None:
+            self.network.shutdown()
 
 
 def socket_loop(sock, experts):
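
Putting the pieces together, a hypothetical end-to-end launch using the classes documented in this commit; the `backend` object is assumed to be built as in the ExpertBackend sketch above:

```python
# Hypothetical end-to-end launch; everything follows the docstrings above
# except the ExpertBackend construction, which is assumed.
from tesseract.network import TesseractNetwork
from tesseract.server import TesseractServer

network = TesseractNetwork()  # or None to skip publishing to the DHT
server = TesseractServer(
    network,
    expert_backends={'ffn_expert': backend},  # see the ExpertBackend sketch
    conn_handler_processes=4,                 # the default of 1 is too small, per docstring
    start=True,                               # runs run_in_background(await_ready=True)
)
assert server.ready.is_set()
# ... serve forward/backward requests ...
server.shutdown()  # stops connection handlers, runtime, and network in order
```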