|
@@ -16,6 +16,7 @@ class GatingFunction(nn.Module):
|
|
|
k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
|
|
|
"""
|
|
|
A torch module that selects experts across the network and averages their predictions
|
|
|
+
|
|
|
:param in_features: common input size for experts and gating function
|
|
|
:param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
:param uid_prefix: common prefix for all expert uids
|