@@ -16,6 +16,7 @@ class GatingFunction(nn.Module):
k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
"""
A torch module that selects experts across the network and averages their predictions
+
:param in_features: common input size for experts and gating function
:param grid_size: tesseract dimensions that form expert uid (see below)
:param uid_prefix: common prefix for all expert uids