Sfoglia il codice sorgente

sphix doc style guide

justheuristic 5 anni fa
parent
commit
0ae8de607d
1 ha cambiato i file con 17 aggiunte e 1 eliminazioni
  1. 17 1
      tesseract/client/gating_function.py

+ 17 - 1
tesseract/client/gating_function.py

@@ -14,6 +14,20 @@ from ..utils import nested_map, check_numpy, run_and_await_k
 class GatingFunction(nn.Module):
     def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
                  k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
+        """
+        A torch module that selects experts across the network and averages their predictions
+        :param in_features: common input size for experts and gating function
+        :param grid_size: tesseract dimensions that form expert uid (see below)
+        :param uid_prefix: common prefix for all expert uids
+            expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
+        :param network: TesseractNetwork where the experts reside
+        :param num_workers: number of threads for parallel network operation
+        :param k_best: queries this many experts with highest scores
+        :param k_min: makes sure at least this many experts returned output
+        :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
+            Any expert that didn't manage to return output after that delay is considered unavailable
+        :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
+        """
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
@@ -25,6 +39,7 @@ class GatingFunction(nn.Module):
     def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
         """
         Choose k best experts with beam search, then call chosen experts and average their outputs.
+
         :param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
         :returns: averaged predictions of all experts that delivered on time
         """
@@ -73,7 +88,8 @@ class GatingFunction(nn.Module):
         :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
         :param k_best: how many of the top experts participate in the computation
         :param kwargs: extra keyword parameters passed to self.network.first_k_active
-        :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains RemoteExpert instances for *up to* k_best experts
+        :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains
+        RemoteExpert instances for *up to* k_best experts
         """
         assert len(grid_scores) == len(self.grid_size)
         assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)