Ver Fonte

get output schema lazily

justheuristic há 5 anos atrás
pai
commit
bf1dcc15ef
1 ficheiros alterados com 2 adições e 2 exclusões
  1. 2 2
      tesseract/client/moe.py

+ 2 - 2
tesseract/client/moe.py

@@ -140,7 +140,7 @@ class RemoteMixtureOfExperts(nn.Module):
         unique_experts = self.network.get_experts(list(set(
             uid for row in beam for uid in row if uid != self.expert_padding)))
         if self._outputs_schema is None:
-            self._output_schema = next(iter(unique_experts)).info['output_schema']
+            self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
         unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
 
         return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
@@ -174,7 +174,7 @@ class RemoteMixtureOfExperts(nn.Module):
         return scores
 
     @property
-    def output_schema(self):
+    def outputs_schema(self):
         if self._outputs_schema is None:
             # grab some expert to set ensemble output shape
             dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)