|
@@ -18,13 +18,13 @@ class GatingFunction(nn.Module):
|
|
|
:param in_features: common input size for experts and gating function
|
|
|
:param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
:param uid_prefix: common prefix for all expert uids
|
|
|
- expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
+ expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
:param network: TesseractNetwork where the experts reside
|
|
|
:param num_workers: number of threads for parallel network operation
|
|
|
:param k_best: queries this many experts with highest scores
|
|
|
:param k_min: makes sure at least this many experts returned output
|
|
|
:param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
|
- Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
+ Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
:param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
|
|
|
"""
|
|
|
def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
|
|
@@ -90,7 +90,7 @@ class GatingFunction(nn.Module):
|
|
|
:param k_best: how many of the top experts participate in the computation
|
|
|
:param kwargs: extra keyword parameters passed to self.network.first_k_active
|
|
|
:returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains \
|
|
|
- RemoteExpert instances for *up to* k_best experts
|
|
|
+ RemoteExpert instances for *up to* k_best experts
|
|
|
"""
|
|
|
assert len(grid_scores) == len(self.grid_size)
|
|
|
assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
|