@@ -567,6 +567,13 @@ class HeadExpert(nn.Module):
         return encoder_outputs
 
 
+@register_expert_class("lm_head_base", head_sample_input)
+class StandardHeadExpert(HeadExpert):
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim)
+        self.config.num_hidden_layers = 1
+
+
 body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)
 
 
@@ -606,6 +613,13 @@ class BodyExpert(nn.Module):
         return hidden_states
 
 
+@register_expert_class("lm_body_base", body_sample_input)
+class StandardBodyExpert(BodyExpert):
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim)
+        self.config.num_hidden_layers = 1
+
+
 tail_sample_input = lambda batch_size, hid_dim: (
     torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),
     torch.randint(0, 1000, (batch_size, SEQUENCE_LENGTH), dtype=torch.long),
@@ -655,3 +669,10 @@ class TailExpert(nn.Module):
         # Flatten the tokens
         loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
         return loss
+
+
+@register_expert_class("lm_tail_base", tail_sample_input)
+class StandardTailExpert(TailExpert):
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim)
+        self.config.num_hidden_layers = 1
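
Not part of the patch itself: a minimal smoke-test sketch for one of the newly registered classes. The import path `lm_experts` and the exact signature of `head_sample_input` are assumptions; the decorator only records the class together with its sample-input factory under the given name, which is roughly the kind of dry run the expert server performs when it builds the I/O schema for a registered expert.

# Hypothetical usage sketch (not from the diff). Assumes the patched module
# is importable as `lm_experts` and that `head_sample_input` follows the same
# (batch_size, hid_dim) convention as body_sample_input and tail_sample_input.
import torch

from lm_experts import StandardHeadExpert, head_sample_input

expert = StandardHeadExpert(1024)            # single hidden layer, hid_dim=1024
dummy_inputs = head_sample_input(4, 1024)    # batch of 4 sample inputs
with torch.no_grad():
    encoder_outputs = expert(*dummy_inputs)  # dry-run forward pass on the sample inputs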
|