|
@@ -19,13 +19,13 @@ class GatingFunction(nn.Module):
|
|
|
:param in_features: common input size for experts and gating function
|
|
|
:param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
:param uid_prefix: common prefix for all expert uids
|
|
|
- expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
+ expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
:param network: TesseractNetwork where the experts reside
|
|
|
:param num_workers: number of threads for parallel network operation
|
|
|
:param k_best: queries this many experts with highest scores
|
|
|
:param k_min: makes sure at least this many experts returned output
|
|
|
:param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
|
- Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
+ Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
:param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
|
|
|
"""
|
|
|
super().__init__()
|
|
@@ -88,7 +88,7 @@ class GatingFunction(nn.Module):
|
|
|
:type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
|
|
|
:param k_best: how many of the top experts participate in the computation
|
|
|
:param kwargs: extra keyword parameters passed to self.network.first_k_active
|
|
|
- :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains
|
|
|
+ :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains \
|
|
|
RemoteExpert instances for *up to* k_best experts
|
|
|
"""
|
|
|
assert len(grid_scores) == len(self.grid_size)
|