|
@@ -177,7 +177,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
def outputs_schema(self):
|
|
|
if self._outputs_schema is None:
|
|
|
# grab some expert to set ensemble output shape
|
|
|
- dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(grid_size, dim=-1)
|
|
|
+ dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(self.grid_size, dim=-1)
|
|
|
self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
|
|
|
return self._outputs_schema
|
|
|
|