
remove dependency on run_and_await_k, rename GatingFunction to RemoteMixtureOfExperts

justheuristic 5 years ago
parent
commit
5016002186
5 changed files with 35 additions and 13 deletions
  1. README.md (+1 -1)
  2. docs/modules/client.rst (+1 -1)
  3. tesseract/client/__init__.py (+1 -1)
  4. tesseract/client/remote_moe.py (+31 -9)
  5. tesseract/utils/threading.py (+1 -1)

+ 1 - 1
README.md

@@ -33,7 +33,7 @@ do something complex with it, please contact us by opening an issue (less prefer
 
 - **`RemoteExpert`**(`tesseract/client/remote_expert.py`) behaves like a pytorch
  module with autograd support but actually sends requests to a remote runtime.
-- **`GatingFunction`**(`tesseract/client/gating_function.py`) finds best experts
+- **`RemoteMixtureOfExperts`**(`tesseract/client/remote_moe.py`) finds best experts
   for a given input and either returns them as `RemoteExpert` or applies them
   right away.
 

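For context on the renamed class, here is a minimal usage sketch. Only in_features, grid_size, k_min and timeout_after_k_min appear in the docstring later in this diff; the concrete values and the forward-call signature below are assumptions for illustration, not the verified API.

# Hypothetical usage sketch: parameter names are taken from the docstring in
# this diff; the values and the call signature are illustrative assumptions.
import torch
from tesseract.client import RemoteMixtureOfExperts

moe = RemoteMixtureOfExperts(in_features=512, grid_size=(4, 4),
                             k_min=2, timeout_after_k_min=1.0)

x = torch.randn(16, 512)   # a batch of inputs
y = moe(x)                 # queries remote experts and averages their outputs
y.sum().backward()         # autograd flows through the experts that responded
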
+ 1 - 1
docs/modules/client.rst

@@ -14,6 +14,6 @@
 .. autoclass:: RemoteExpert
    :members: forward
 
-.. autoclass:: GatingFunction
+.. autoclass:: RemoteMixtureOfExperts
    :members:
    :member-order: bysource

+ 1 - 1
tesseract/client/__init__.py

@@ -1,2 +1,2 @@
-from .gating_function import GatingFunction
+from .remote_moe import RemoteMixtureOfExperts
 from .remote_expert import RemoteExpert

+ 31 - 9
tesseract/client/gating_function.py → tesseract/client/remote_moe.py

@@ -1,5 +1,6 @@
 import multiprocessing as mp
 import multiprocessing.pool
+from concurrent.futures import Future, as_completed, TimeoutError
 from functools import partial
 from typing import Tuple, List, Dict, Any
 
@@ -8,12 +9,17 @@ import torch
 import torch.nn as nn
 
 from .remote_expert import RemoteExpert
-from ..utils import nested_map, check_numpy, run_and_await_k
+from ..utils import nested_map, check_numpy, run_in_background
 
 
-class GatingFunction(nn.Module):
+class RemoteMixtureOfExperts(nn.Module):
     """
-    A torch module that selects experts across the network and averages their predictions
+    A torch module that performs mixture-of-experts inference with a local gating function and multiple remote experts.
+    Natively supports pytorch autograd.
+
+    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all experts that ran
+     the forward pass are guaranteed to perform a backward pass. In the latter case, gradients will be averaged
+     without the missing experts.
 
     :param in_features: common input size for experts and gating function
     :param grid_size: tesseract dimensions that form expert uid (see below)
@@ -140,16 +146,32 @@ class GatingFunction(nn.Module):
             for row in beam]
 
     def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
-        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
-                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
-        return {expert: output for expert, output in zip(experts, outputs)
-                if not isinstance(output, BaseException)}
+        future_to_expert = {run_in_background(expert, *args, **kwargs): expert for expert in experts}
+        pending_futures = set(future_to_expert.keys())
+        outputs = {}  # {expert -> output}
+
+        # await the first k_min successful futures for as long as it takes
+        for future in as_completed(list(pending_futures), timeout=None):
+            pending_futures.remove(future)
+            if not future.exception():
+                outputs[future_to_expert.pop(future)] = future.result()
+                if len(outputs) >= self.k_min:
+                    break
+
+        # await stragglers for at most self.timeout_after_k_min, then drop them
+        try:
+            for future in as_completed(pending_futures, timeout=self.timeout_after_k_min):
+                if not future.exception():
+                    outputs[future_to_expert.pop(future)] = future.result()
+        except TimeoutError:
+            pass  # experts that miss the deadline are excluded from the outputs
+
+        return outputs
 
     def _score_experts(self, grid_scores: List[torch.Tensor],
                        experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
         flat_experts = [expert for row in experts for expert in row]
-        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts)
-                                           for uid in range(len(row))])
+        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) for uid in range(len(row))])
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):

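The two-phase await in _run_experts above (block until the first k_min successes, then give stragglers a fixed deadline) can be sketched standalone with plain concurrent.futures. The call_expert stand-in and its timings are hypothetical:

# Self-contained sketch of the wait-for-k-then-deadline pattern used above;
# call_expert is a hypothetical stand-in for a remote expert call.
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
import random
import time

def call_expert(i):
    time.sleep(random.uniform(0.0, 0.5))  # simulate variable network latency
    return f"output_{i}"

k_min, timeout_after_k_min = 2, 0.1

with ThreadPoolExecutor() as pool:
    pending = {pool.submit(call_expert, i): i for i in range(5)}
    outputs = {}

    # phase 1: block until the first k_min calls succeed, however long it takes
    for future in as_completed(list(pending)):
        expert = pending.pop(future)
        if not future.exception():
            outputs[expert] = future.result()
            if len(outputs) >= k_min:
                break

    # phase 2: give the stragglers a fixed deadline, then stop waiting
    try:
        for future in as_completed(pending, timeout=timeout_after_k_min):
            expert = pending.pop(future)
            if not future.exception():
                outputs[expert] = future.result()
    except TimeoutError:
        pass  # experts that missed the deadline are simply excluded

print(outputs)  # at least k_min entries, plus any stragglers that made it
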
+ 1 - 1
tesseract/utils/threading.py

@@ -4,7 +4,7 @@ from itertools import count
 from threading import Thread, Event, Lock
 
 
-def run_in_background(func: callable, *args, **kwargs):
+def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run f(*args, **kwargs) in background and return Future for its outputs """
     future = Future()
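The hunk above is cut off after future = Future(). For reference, a complete helper with this signature could look like the sketch below; the body is an assumption, not the actual contents of tesseract/utils/threading.py:

# Assumed completion of the truncated helper above -- not the actual file body.
from concurrent.futures import Future
from threading import Thread

def run_in_background(func: callable, *args, **kwargs) -> Future:
    """ run func(*args, **kwargs) in background and return a Future for its output """
    future = Future()

    def _work():
        try:
            future.set_result(func(*args, **kwargs))
        except BaseException as exception:
            future.set_exception(exception)

    Thread(target=_work, daemon=True).start()
    return future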