Max Ryabinin committed 2 years ago
parent commit 03ebe9c6c6

+21 −0
hivemind/moe/server/layers/albert.py

@@ -567,6 +567,13 @@ class HeadExpert(nn.Module):
         return encoder_outputs
 
 
+@register_expert_class("lm_head_base", head_sample_input)
+class StandardHeadExpert(HeadExpert):
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim)
+        self.config.num_hidden_layers = 1
+
+
 body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)
 
 
@@ -606,6 +613,13 @@ class BodyExpert(nn.Module):
         return hidden_states
 
 
+@register_expert_class("lm_body_base", body_sample_input)
+class StandardBodyExpert(BodyExpert):
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim)
+        self.config.num_hidden_layers = 1
+
+
 tail_sample_input = lambda batch_size, hid_dim: (
     torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),
     torch.randint(0, 1000, (batch_size, SEQUENCE_LENGTH), dtype=torch.long),
@@ -655,3 +669,10 @@ class TailExpert(nn.Module):
         # Flatten the tokens
         loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
         return loss
+
+
+@register_expert_class("lm_tail_base", tail_sample_input)
+class StandardTailExpert(TailExpert):
+    def __init__(self, hid_dim):
+        super().__init__(hid_dim)
+        self.config.num_hidden_layers = 1
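For context, the three additions above register single-layer variants of the existing head/body/tail experts under the names "lm_head_base", "lm_body_base", and "lm_tail_base". Below is a minimal sketch of how such registered experts could be looked up and exercised through hivemind's expert-layer registry; the name_to_block / name_to_input registries come from hivemind's layers package, while the explicit albert import, hid_dim, and batch_size values are illustrative assumptions and not part of this commit.

    import torch

    import hivemind.moe.server.layers.albert  # assumed import path; triggers the registrations above
    from hivemind.moe.server.layers import name_to_block, name_to_input

    hid_dim, batch_size = 1024, 4  # hypothetical sizes, chosen only for illustration

    # register_expert_class makes each class retrievable by its registered name
    body = name_to_block["lm_body_base"](hid_dim)  # StandardBodyExpert from this commit

    # each registered name also maps to a sample-input factory (body_sample_input above)
    sample_inputs = name_to_input["lm_body_base"](batch_size, hid_dim)
    hidden_states = body(*sample_inputs)  # runs the single-layer body expert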