|
@@ -12,23 +12,23 @@ from ..utils import nested_map, check_numpy, run_and_await_k
|
|
|
|
|
|
|
|
|
class GatingFunction(nn.Module):
|
|
|
+ """
|
|
|
+ A torch module that selects experts across the network and averages their predictions
|
|
|
+
|
|
|
+ :param in_features: common input size for experts and gating function
|
|
|
+ :param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
+ :param uid_prefix: common prefix for all expert uids
|
|
|
+ expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
+ :param network: TesseractNetwork where the experts reside
|
|
|
+ :param num_workers: number of threads for parallel network operation
|
|
|
+ :param k_best: queries this many experts with highest scores
|
|
|
+ :param k_min: makes sure at least this many experts returned output
|
|
|
+ :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
|
+ Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
+ :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
|
|
|
+ """
|
|
|
def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
|
|
|
k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
|
|
|
- """
|
|
|
- A torch module that selects experts across the network and averages their predictions
|
|
|
-
|
|
|
- :param in_features: common input size for experts and gating function
|
|
|
- :param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
- :param uid_prefix: common prefix for all expert uids
|
|
|
- expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
- :param network: TesseractNetwork where the experts reside
|
|
|
- :param num_workers: number of threads for parallel network operation
|
|
|
- :param k_best: queries this many experts with highest scores
|
|
|
- :param k_min: makes sure at least this many experts returned output
|
|
|
- :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
|
- Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
- :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
|
|
|
- """
|
|
|
super().__init__()
|
|
|
self.network, self.grid_size = network, grid_size
|
|
|
self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
|