@@ -1,5 +1,6 @@
 import multiprocessing as mp
 import multiprocessing.pool
+from concurrent.futures import Future, as_completed, TimeoutError
 from functools import partial
 from typing import Tuple, List, Dict, Any
 
@@ -8,12 +9,17 @@ import torch
 import torch.nn as nn
 
 from .remote_expert import RemoteExpert
-from ..utils import nested_map, check_numpy, run_and_await_k
+from ..utils import nested_map, check_numpy, run_in_background
 
 
-class GatingFunction(nn.Module):
+class RemoteMixtureOfExperts(nn.Module):
     """
-    A torch module that selects experts across the network and averages their predictions
+    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
+    Natively supports pytorch autograd.
+
+    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all experts that ran
+        the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient will be
+        averaged without the missing experts.
 
     :param in_features: common input size for experts and gating function
     :param grid_size: tesseract dimensions that form expert uid (see below)
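Note: run_in_background replaces the removed run_and_await_k helper; the caller now manages the resulting concurrent.futures.Future objects itself. The helper lives in ..utils and is not shown in this diff. A minimal sketch of what it presumably looks like, assuming a shared thread pool (the pool name and size are illustrative, not from the source):

from concurrent.futures import Future, ThreadPoolExecutor

_background_pool = ThreadPoolExecutor(max_workers=64)  # hypothetical shared pool

def run_in_background(func, *args, **kwargs) -> Future:
    # submit the callable to a background thread; callers collect results via the Future
    return _background_pool.submit(func, *args, **kwargs)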
@@ -140,16 +146,32 @@ class GatingFunction(nn.Module):
                        for row in beam]
 
     def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
-        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
-                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
-        return {expert: output for expert, output in zip(experts, outputs)
-                if not isinstance(output, BaseException)}
+        future_to_expert = {run_in_background(expert, *args, **kwargs): expert for expert in experts}
+        pending_futures = set(future_to_expert.keys())
+        outputs = {}  # {expert -> output}
+
+        # await the first k_min futures for as long as it takes
+        for future in as_completed(list(pending_futures), timeout=None):
+            pending_futures.remove(future)
+            if not future.exception():
+                outputs[future_to_expert.pop(future)] = future.result()
+            if len(outputs) >= self.k_min:
+                break
+
+        # await stragglers for at most self.timeout_after_k_min
+        try:
+            for future in as_completed(pending_futures, timeout=self.timeout_after_k_min):
+                if not future.exception():
+                    outputs[future_to_expert.pop(future)] = future.result()
+        except TimeoutError:
+            pass  # experts that miss the deadline are simply excluded from outputs
+
+        return outputs
 
     def _score_experts(self, grid_scores: List[torch.Tensor],
                        experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
         flat_experts = [expert for row in experts for expert in row]
-        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts)
-                                           for uid in range(len(row))])
+        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) for uid in range(len(row))])
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):
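The new _run_experts implements a two-phase quorum: block until the first k_min experts succeed, however long that takes, then give the remaining stragglers at most timeout_after_k_min seconds. Note that as_completed raises concurrent.futures.TimeoutError once its deadline passes with futures still pending, which is why the straggler loop is wrapped in try/except. A self-contained sketch of the same pattern (the function and task names are illustrative, not from the codebase):

import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed

def run_with_quorum(tasks, k_min, timeout_after_k_min):
    # run callables in parallel; return {task: result} with at least k_min entries when possible
    pool = ThreadPoolExecutor(max_workers=len(tasks))
    future_to_task = {pool.submit(task): task for task in tasks}
    pending, outputs = set(future_to_task), {}
    try:
        # phase 1: wait for k_min successes, no matter how long it takes
        for future in as_completed(list(pending)):
            pending.remove(future)
            if not future.exception():
                outputs[future_to_task[future]] = future.result()
                if len(outputs) >= k_min:
                    break
        # phase 2: give stragglers a bounded grace period, then move on
        try:
            for future in as_completed(pending, timeout=timeout_after_k_min):
                if not future.exception():
                    outputs[future_to_task[future]] = future.result()
        except TimeoutError:
            pass  # tasks that miss the deadline are dropped, as in _run_experts
    finally:
        pool.shutdown(wait=False)  # return immediately instead of joining stragglers
    return outputs

fast, slow = (lambda: "fast"), (lambda: time.sleep(5.0) or "slow")
print(run_with_quorum([fast, slow], k_min=1, timeout_after_k_min=0.5))  # only the fast task's result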
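In _score_experts, the reflowed flat_batch_indices line records, for every expert in the flattened list, which batch row selected it. A small illustration with a hypothetical nested experts list:

import torch

experts = [["e0a", "e0b"], ["e1a"]]  # batch row 0 selected two experts, row 1 selected one
flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) for _ in range(len(row))])
print(flat_batch_indices)  # tensor([0, 0, 1])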
|