|
@@ -14,6 +14,20 @@ from ..utils import nested_map, check_numpy, run_and_await_k
|
|
|
class GatingFunction(nn.Module):
|
|
|
def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
|
|
|
k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
|
|
|
+ """
|
|
|
+ A torch module that selects experts across the network and averages their predictions
|
|
|
+ :param in_features: common input size for experts and gating function
|
|
|
+ :param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
+ :param uid_prefix: common prefix for all expert uids
|
|
|
+ expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
+ :param network: TesseractNetwork where the experts reside
|
|
|
+ :param num_workers: number of threads for parallel network operation
|
|
|
+ :param k_best: queries this many experts with highest scores
|
|
|
+ :param k_min: makes sure at least this many experts returned output
|
|
|
+ :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
|
+ Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
+ :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
|
|
|
+ """
|
|
|
super().__init__()
|
|
|
self.network, self.grid_size = network, grid_size
|
|
|
self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
|
|
@@ -25,6 +39,7 @@ class GatingFunction(nn.Module):
|
|
|
def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
|
|
|
"""
|
|
|
Choose k best experts with beam search, then call chosen experts and average their outputs.
|
|
|
+
|
|
|
:param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
|
|
|
:returns: averaged predictions of all experts that delivered on time
|
|
|
"""
|
|
@@ -73,7 +88,8 @@ class GatingFunction(nn.Module):
|
|
|
:type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
|
|
|
:param k_best: how many of the top experts participate in the computation
|
|
|
:param kwargs: extra keyword parameters passed to self.network.first_k_active
|
|
|
- :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains RemoteExpert instances for *up to* k_best experts
|
|
|
+ :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains
|
|
|
+ RemoteExpert instances for *up to* k_best experts
|
|
|
"""
|
|
|
assert len(grid_scores) == len(self.grid_size)
|
|
|
assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
|