|
@@ -48,7 +48,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
|
|
|
# grab some expert to set ensemble output shape
|
|
|
dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)
|
|
|
- self.output_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['output_schema']
|
|
|
+ self.output_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
|
|
|
|
|
|
def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
|
|
|
"""
|