|
@@ -571,7 +571,6 @@ class HeadExpert(nn.Module):
|
|
|
class StandardHeadExpert(HeadExpert):
|
|
|
def __init__(self, hid_dim):
|
|
|
super().__init__(hid_dim)
|
|
|
- self.config.num_hidden_layers = 1
|
|
|
|
|
|
|
|
|
body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)
|