瀏覽代碼

get output schema lazily

justheuristic 5 年之前
父節點
當前提交
b2c2b22f26
共有 1 個文件被更改,包括 3 次插入和 3 次删除
  1. 3 3
      tesseract/client/moe.py

+ 3 - 3
tesseract/client/moe.py

@@ -37,14 +37,14 @@ class RemoteMixtureOfExperts(nn.Module):
      allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
      allow_broadcasting=False will raise an error
     """
-    def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
-                 k_best, k_min=1, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
+    def __init__(self, *, in_features, grid_size: Tuple[int], network, k_best, k_min=1,
+                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
                  uid_prefix='', expert_padding=None, allow_broadcasting=True):
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
-        self.timeout_after_k_min, self.backward_timeout = timeout_after_k_min, backward_timeout
+        self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions