Max Ryabinin 4 years ago
parent
commit
5ff2963a75

+ 5 - 7
hivemind/moe/client/moe.py

@@ -15,7 +15,7 @@ from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_map, nested_pack
+from hivemind.utils import nested_flatten, nested_map, nested_pack, nested_compare
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -84,13 +84,9 @@ class RemoteMixtureOfExperts(nn.Module):
         :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
         :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
         """
-        if input.ndim != 2:
-            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
-        else:
-            input_for_gating = input
 
         # 1. compute scores and find most appropriate experts with beam search
-        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
+        grid_scores = [torch.randn(input.size(0), dim) for dim in self.beam_search.grid_size]
 
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
@@ -107,6 +103,8 @@ class RemoteMixtureOfExperts(nn.Module):
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
+        flat_inputs = nested_flatten(((input, *args), kwargs))
+
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             chosen_experts,
@@ -118,7 +116,7 @@ class RemoteMixtureOfExperts(nn.Module):
             self.detect_anomalies,
             self.allow_zero_outputs,
             self.info,
-            *nested_flatten(((input, *args), kwargs)),
+            *flat_inputs,
         )
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 

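Two things happen in moe.py: the flattened `((input, *args), kwargs)` tuple is hoisted into `flat_inputs` (and `nested_compare` is newly imported), and the learned gating projection `self.proj` is swapped for random per-grid scores, apparently as a temporary stand-in. A minimal sketch of the flattening step, with made-up toy shapes, assuming hivemind's nested utilities as defined in `hivemind/utils/nested.py`:

```python
import torch
from hivemind.utils import nested_flatten, nested_compare

# Toy batch-first tensors (shapes are made up for illustration).
input, extra = torch.randn(4, 16), torch.randn(4, 8)
kwargs = {"mask": torch.ones(4, 16)}

# Same flattening as in forward(): the nested ((input, *args), kwargs) structure
# becomes a flat sequence of leaf tensors passed to _RemoteCallMany.apply as *flat_inputs.
flat_inputs = list(nested_flatten(((input, extra), kwargs)))
assert len(flat_inputs) == 3

# nested_compare (newly imported) checks that two nested structures match,
# e.g. actual inputs vs. an expert's advertised args_schema.
assert nested_compare(((input, extra), kwargs), ((input, extra), kwargs))
```
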
+ 174 - 87
hivemind/moe/server/layers/albert.py

@@ -55,26 +55,24 @@ class LeanAlbertConfig(AlbertConfig):
 class LeanFFN(nn.Module):
     """
     A transformer FFN module that doesn't hog your GPU memory.
-    Complete with pre-LayerNorm, residual connections and dropout.
+    Complete with pre-LayerNorm and residual connections.
     :param gated: use gated activations based on https://arxiv.org/abs/2002.05202 and https://arxiv.org/abs/2102.11972
       note: gated activations require 1.5x more parameters compared to their non-gated variants.
     """
 
     def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        activation=F.gelu,
-        gated: bool = False,
-        layer_norm_eps: float = 1e-12,
-        dropout: float = 0.0,
+            self,
+            hidden_size: int,
+            intermediate_size: int,
+            activation=F.gelu,
+            gated: bool = False,
+            layer_norm_eps: float = 1e-12,
     ):
         super().__init__()
         self.dense_i2h = nn.Linear(hidden_size, intermediate_size * 2 if gated else intermediate_size)
         self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
         self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
         self.activation = activation
-        self.dropout = dropout
 
     def forward(self, input):
         return _LeanFFN.apply(
@@ -86,7 +84,6 @@ class LeanFFN(nn.Module):
             self.dense_h2o.weight,
             self.dense_h2o.bias,
             self.activation,
-            self.dropout,
             self.training,
             self.layer_norm.eps,
         )
@@ -106,20 +103,19 @@ class _LeanFFN(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx,
-        input,
-        ln_weight,
-        ln_bias,
-        i2h_weight,
-        i2h_bias,
-        h2o_weight,
-        h2o_bias,
-        activation,
-        dropout,
-        training,
-        ln_eps,
+            ctx,
+            input,
+            ln_weight,
+            ln_bias,
+            i2h_weight,
+            i2h_bias,
+            h2o_weight,
+            h2o_bias,
+            activation,
+            training,
+            ln_eps,
     ):
-        ctx._activation, ctx._dropout, ctx._training, ctx._ln_eps = activation, dropout, training, ln_eps
+        ctx._activation, ctx._training, ctx._ln_eps = activation, training, ln_eps
         ctx._cpu_rng_state = torch.get_rng_state()
         ctx._device_rng_states = get_device_states(input)
 
@@ -131,7 +127,6 @@ class _LeanFFN(torch.autograd.Function):
         hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, h2o_weight.shape[1])
 
         out = F.linear(hid_act, h2o_weight, h2o_bias)
-        out = F.dropout(out, dropout, training, inplace=True)
         out = out.add_(input_2d)
         ctx.save_for_backward(input, pre_activation, ln_weight, ln_bias, i2h_weight, h2o_weight)
         return out.view(*input.shape)
@@ -184,17 +179,17 @@ class _LeanFFN(torch.autograd.Function):
             grad_h2o_bias = grad_output_2d.sum(0)
 
         return (
-            grad_input,
-            grad_ln_weight,
-            grad_ln_bias,
-            grad_i2h_weight,
-            grad_i2h_bias,
-            grad_h2o_weight,
-            grad_h2o_bias,
-            None,
-            None,
-            None,
-            None,
+                grad_input,
+                grad_ln_weight,
+                grad_ln_bias,
+                grad_i2h_weight,
+                grad_i2h_bias,
+                grad_h2o_weight,
+                grad_h2o_bias,
+                None,
+                None,
+                None,
+                None,
         )
 
 
@@ -217,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
             self.register_buffer("cos", cos)
             self.register_buffer("sin", sin)
 
-        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])
+        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])
 
 
 @torch.no_grad()
@@ -248,18 +243,18 @@ def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tenso
 
 class LeanSelfAttention(nn.Module):
     def __init__(
-        self,
-        hidden_size: int,
-        num_attention_heads: int,
-        attention_core: Optional[nn.Module] = None,
-        hidden_dropout_prob: float = 0,
-        layer_norm_eps: float = 1e-12,
-        **kwargs,
+            self,
+            hidden_size: int,
+            num_attention_heads: int,
+            max_positions: int,
+            attention_core: Optional[nn.Module] = None,
+            layer_norm_eps: float = 1e-12,
+            **kwargs,
     ):
         """Attention layer that does not hog GPU memory"""
         super().__init__()
         if attention_core is None:
-            attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, **kwargs)
+            attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, max_positions, **kwargs)
         else:
             assert len(kwargs) == 0, f"Unexpected parameters: {kwargs}"
 
@@ -268,7 +263,6 @@ class LeanSelfAttention(nn.Module):
         self.dense_qkv = nn.Linear(hidden_size, hidden_size * 3)
         self.dense_out = nn.Linear(hidden_size, hidden_size)
         self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.output_dropout = nn.Dropout(hidden_dropout_prob, inplace=False)
 
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         hidden_states_ln = self.layer_norm(hidden_states)
@@ -276,19 +270,25 @@ class LeanSelfAttention(nn.Module):
         query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
         attention_output, attention_probs = checkpoint(self.attention_core, query, key, value, attention_mask)
         projected_context_layer = self.dense_out(attention_output)
-        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
-        layernormed_context_layer = projected_context_layer_dropout + hidden_states
+        layernormed_context_layer = projected_context_layer + hidden_states
         return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
 
 
 class SimpleAttentionCore(nn.Module):
-    def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0):
+    def __init__(self, hidden_size: int, num_attention_heads: int, max_positions):
         super().__init__()
         assert hidden_size % num_attention_heads == 0
-        self.attention_dropout = nn.Dropout(attention_probs_dropout, inplace=False)
         self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
         self.attention_head_size = hidden_size // num_attention_heads
 
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4))
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -310,6 +310,10 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = torch.matmul(query, key.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(query.shape[-1])
 
+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
+        attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))
+
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_mask
@@ -317,10 +321,6 @@ class SimpleAttentionCore(nn.Module):
         # Normalize the attention scores to probabilities.
         attention_probs = torch.softmax(attention_scores, dim=-1)
 
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.attention_dropout(attention_probs)
-
         attention_output = torch.matmul(attention_probs, value)
         attention_output = attention_output.transpose(2, 1).flatten(2)
         return attention_output, attention_probs
@@ -330,9 +330,14 @@ class RotaryAttentionCore(SimpleAttentionCore):
     """Attention core that applies rotary embeddings to queries and keys before computing dot products"""
 
     def __init__(
-        self, hidden_size: int, num_attention_heads: int, rotary_emb: Optional[RotaryEmbeddings] = None, **kwargs
+            self,
+            hidden_size: int,
+            num_attention_heads: int,
+            max_positions: int,
+            rotary_emb: Optional[RotaryEmbeddings] = None,
+            **kwargs,
     ):
-        super().__init__(hidden_size, num_attention_heads, **kwargs)
+        super().__init__(hidden_size, num_attention_heads, max_positions, **kwargs)
         if rotary_emb is None:
             rotary_emb = RotaryEmbeddings(self.attention_head_size)
         self.rotary_emb = rotary_emb
@@ -362,10 +367,7 @@ def get_attention_core(config: LeanAlbertConfig):
     elif config.position_embedding_type == "rotary":
         rotary_emb = RotaryEmbeddings(config.hidden_size // config.num_attention_heads, config.rotary_embedding_base)
         return RotaryAttentionCore(
-            config.hidden_size,
-            config.num_attention_heads,
-            rotary_emb,
-            attention_probs_dropout=config.attention_probs_dropout_prob,
+            config.hidden_size, config.num_attention_heads, config.max_position_embeddings, rotary_emb
         )
     else:
         raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}")
@@ -383,7 +385,6 @@ class LeanAlbertEmbeddings(nn.Module):
         self.position_embeddings = get_input_embedding(config)
 
         self.layernorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
         if self.position_embeddings is not None:
             # position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -392,7 +393,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+            self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -412,12 +413,11 @@ class LeanAlbertEmbeddings(nn.Module):
 
         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings
 
         embeddings = self.layernorm(embeddings)
-        embeddings = self.dropout(embeddings)
         return embeddings
 
 
@@ -433,8 +433,8 @@ class LeanAlbertLayer(nn.Module):
         self.attention = LeanSelfAttention(
             config.hidden_size,
             config.num_attention_heads,
+            config.max_position_embeddings,
             attention_core=get_attention_core(config),
-            hidden_dropout_prob=config.hidden_dropout_prob,
             layer_norm_eps=config.layer_norm_eps,
         )
 
@@ -444,7 +444,6 @@ class LeanAlbertLayer(nn.Module):
             activation=ACT2FN[config.hidden_act],
             gated=config.hidden_act_gated,
             layer_norm_eps=config.layer_norm_eps,
-            dropout=config.hidden_dropout_prob,
         )
 
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
@@ -459,9 +458,10 @@ class LeanAlbertLayerGroup(AlbertLayerGroup):
         self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])
 
     def forward(
-        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
+            self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False,
+            output_hidden_states=False
     ):
-        if any(head_mask):
+        if head_mask is not None and any(head_mask):
             raise NotImplementedError(f"head mask was provided, but it is not supported")
 
         layer_hidden_states = ()
@@ -496,13 +496,13 @@ class LeanAlbertTransformer(AlbertTransformer):
         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
 
     def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+            self,
+            hidden_states,
+            attention_mask=None,
+            head_mask=None,
+            output_attentions=False,
+            output_hidden_states=False,
+            return_dict=True,
     ):
         # TODO this should entire be replaced with inheritance and post_layer_norm
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
@@ -520,7 +520,7 @@ class LeanAlbertTransformer(AlbertTransformer):
             layer_group_output = self.albert_layer_groups[group_idx](
                 hidden_states,
                 attention_mask,
-                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
+                None,
                 output_attentions,
                 output_hidden_states,
             )
@@ -532,11 +532,13 @@ class LeanAlbertTransformer(AlbertTransformer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
+        hidden_states = self.post_layer_norm(hidden_states)
+
         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
 
         return BaseModelOutput(
-            last_hidden_state=self.post_layer_norm(hidden_states),
+            last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
             attentions=all_attentions,
         )
@@ -582,24 +584,109 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
 
 from hivemind.moe.server.layers.custom_experts import register_expert_class
 
-albert_sample_input = lambda batch_size, hid_dim: (
-    torch.empty((batch_size, 512, hid_dim)),
-    torch.ones((batch_size, 512)),
+head_sample_input = lambda batch_size, hid_dim: (
+        torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+)
+
+
+@register_expert_class("lm_head", head_sample_input)
+class HeadExpert(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
+        config.hidden_size = hid_dim
+        config.num_hidden_layers = 12
+
+        self.encoder = LeanAlbertTransformer(config)
+        self.embeddings = LeanAlbertEmbeddings(config)
+
+    def forward(self, input_ids):
+        embedding_output = self.embeddings(input_ids)
+        encoder_outputs, = self.encoder(embedding_output, return_dict=False)
+
+        return encoder_outputs
+
+
+body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, 512, hid_dim)),)
+
+
+@register_expert_class("lm_body", body_sample_input)
+class BodyExpert(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
+        config.hidden_size = hid_dim
+        config.num_hidden_layers = 12
+
+        self.config = config
+        self.albert_layer_groups = nn.ModuleList(
+            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
+        )
+        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        for i in range(self.config.num_hidden_layers):
+            # Index of the hidden group
+            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+
+            layer_group_output = self.albert_layer_groups[group_idx](
+                hidden_states,
+                None,
+                None,
+                False,
+                False,
+            )
+            hidden_states = layer_group_output[0]
+
+        hidden_states = self.post_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+tail_sample_input = lambda batch_size, hid_dim: (
+
+        torch.empty((batch_size, 512, hid_dim)),
+        torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
 )
 
 
-@register_expert_class("albert", albert_sample_input)
-class LeanAlbertExpert(nn.Module):
+@register_expert_class("lm_tail", tail_sample_input)
+class TailExpert(nn.Module):
     def __init__(self, hid_dim):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.num_hidden_layers = 12
+
+        self.config = config
+        self.albert_layer_groups = nn.ModuleList(
+            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
+        )
+        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
+
+        self.lm_head = AlbertMLMHead(config)
+
+    def forward(self, hidden_states, labels):
+        for i in range(self.config.num_hidden_layers):
+            # Index of the hidden group
+            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+
+            layer_group_output = self.albert_layer_groups[group_idx](
+                hidden_states,
+                None,
+                None,
+                False,
+                False,
+            )
+            hidden_states = layer_group_output[0]
 
-        self.layer = LeanAlbertLayer(config)
+        hidden_states = self.post_layer_norm(hidden_states)
 
-    def forward(self, x, attention_mask):
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        extended_attention_mask = extended_attention_mask.to(dtype=x.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        lm_logits = self.lm_head(hidden_states)
 
-        return self.layer(x, attention_mask=extended_attention_mask)[0]
+        # Shift so that tokens < n predict n
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
+        return loss

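Besides removing dropout throughout, this change makes the attention causal: `SimpleAttentionCore` now registers a lower-triangular `bias` buffer sized by `max_positions` and fills future positions with `-1e4` before the softmax, in the style of GPT-2, and the single `albert` expert is split into `lm_head`, `lm_body` and `lm_tail` stages. A self-contained sketch of the masking step with toy sizes (pure torch, not tied to the classes above):

```python
import torch

# Toy sizes; in the layer above, max_positions comes from config.max_position_embeddings.
max_positions, query_length, key_length = 8, 4, 4

bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)
masked_bias = torch.tensor(-1e4)

scores = torch.randn(1, 1, query_length, key_length)  # [batch, heads, q_len, k_len]
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()
scores = torch.where(causal_mask, scores, masked_bias.to(scores.dtype))

probs = torch.softmax(scores, dim=-1)
# Each query attends only to itself and earlier keys; future positions get ~zero weight.
assert torch.allclose(probs.triu(diagonal=1), torch.zeros_like(probs))
```
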
+ 18 - 0
hivemind/utils/nested.py

@@ -1,4 +1,6 @@
 """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
+import torch
+from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
 
 
 def nested_compare(t, u):
@@ -25,6 +27,22 @@ def nested_compare(t, u):
                 return False
         return True
 
+    # if isinstance(t, torch.Tensor):
+    #     if isinstance(u, torch.Tensor) and t.size() != u.size():
+    #         return False
+    #     if isinstance(u, TensorDescriptor) and t.size() != u.size:
+    #         return False
+    #     if isinstance(u, BatchTensorDescriptor) and t.size()[1:] != u.size[1:]:
+    #         return False
+
+    # if hasattr(t, '__iter__'):
+    #     print(t, u)
+    #     if not hasattr(u, '__iter__'):
+    #         return False
+    #     for a, b in zip(t, u):
+    #         if not nested_compare(a, b):
+    #             return False
+
     else:
         return True
 

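As it stands, `nested_compare` only checks the nesting structure; the commented-out block sketches how it could also compare tensor sizes against `TensorDescriptor`/`BatchTensorDescriptor` schemas. A quick illustration of the current behaviour (made-up shapes):

```python
import torch
from hivemind.utils import nested_compare

# Same nesting, completely different shapes: the structure-only comparison still passes.
a = ((torch.zeros(2, 3),), {"mask": torch.ones(2, 3)})
b = ((torch.zeros(5, 7),), {"mask": torch.ones(1)})
assert nested_compare(a, b)  # shapes are ignored today
```
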
+ 3 - 0
hivemind/utils/tensor_descr.py

@@ -49,6 +49,9 @@ class TensorDescriptor(DescriptorBase):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
+        if self.dtype == torch.long:
+            return torch.zeros(**properties)
+
         return torch.empty(**properties)
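
A plausible reason for the `torch.zeros` special case (not stated in the commit): dummy batches of token ids are materialized from descriptors with `dtype=torch.long`, and `torch.empty` leaves that memory uninitialized, so the garbage values can overflow an embedding lookup, whereas zeros are always valid indices. A small illustration:

```python
import torch

emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=16)

good_ids = torch.zeros((2, 8), dtype=torch.long)  # always a valid token id (0)
emb(good_ids)                                     # fine

bad_ids = torch.empty((2, 8), dtype=torch.long)   # arbitrary uninitialized integers
# emb(bad_ids)  # may raise "index out of range in self" if any value falls outside [0, 1000)
```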