@@ -571,7 +571,6 @@ class HeadExpert(nn.Module):
 class StandardHeadExpert(HeadExpert):
     def __init__(self, hid_dim):
         super().__init__(hid_dim)
-        self.config.num_hidden_layers = 1
 
 
 body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)