|
@@ -26,7 +26,7 @@ class GatingFunction(nn.Module):
|
|
"""
|
|
"""
|
|
Choose k best experts with beam search, then call chosen experts and average their outputs.
|
|
Choose k best experts with beam search, then call chosen experts and average their outputs.
|
|
:param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
|
|
:param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
|
|
- :return: averaged predictions of all experts that delivered on time
|
|
|
|
|
|
+ :returns: averaged predictions of all experts that delivered on time
|
|
"""
|
|
"""
|
|
assert len(input.shape) == 2
|
|
assert len(input.shape) == 2
|
|
|
|
|
|
@@ -68,12 +68,12 @@ class GatingFunction(nn.Module):
|
|
def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
|
|
def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
|
|
"""
|
|
"""
|
|
Find and return k best experts in the grid using (exact) beam search of the product space
|
|
Find and return k best experts in the grid using (exact) beam search of the product space
|
|
|
|
+
|
|
:param grid_scores: scores predicted for each dimension in the grid,
|
|
:param grid_scores: scores predicted for each dimension in the grid,
|
|
:type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
|
|
:type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
|
|
:param k_best: how many of the top experts participate in the computation
|
|
:param k_best: how many of the top experts participate in the computation
|
|
:param kwargs: extra keyword parameters passed to self.network.first_k_active
|
|
:param kwargs: extra keyword parameters passed to self.network.first_k_active
|
|
- :returns: a list of *batch_size* lists that contain chosen experts for one sample
|
|
|
|
- each inner list contains RemoteExpert instances for *up to* k_best experts
|
|
|
|
|
|
+ :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains RemoteExpert instances for *up to* k_best experts
|
|
"""
|
|
"""
|
|
assert len(grid_scores) == len(self.grid_size)
|
|
assert len(grid_scores) == len(self.grid_size)
|
|
assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
|
|
assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
|