@@ -61,12 +61,12 @@ class LeanFFN(nn.Module):
     """

     def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        activation=F.gelu,
-        gated: bool = False,
-        layer_norm_eps: float = 1e-12,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        activation=F.gelu,
+        gated: bool = False,
+        layer_norm_eps: float = 1e-12,
     ):
         super().__init__()
         self.dense_i2h = nn.Linear(hidden_size, intermediate_size * 2 if gated else intermediate_size)
@@ -103,17 +103,17 @@ class _LeanFFN(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx,
-        input,
-        ln_weight,
-        ln_bias,
-        i2h_weight,
-        i2h_bias,
-        h2o_weight,
-        h2o_bias,
-        activation,
-        training,
-        ln_eps,
+        ctx,
+        input,
+        ln_weight,
+        ln_bias,
+        i2h_weight,
+        i2h_bias,
+        h2o_weight,
+        h2o_bias,
+        activation,
+        training,
+        ln_eps,
     ):
         ctx._activation, ctx._training, ctx._ln_eps = activation, training, ln_eps
         ctx._cpu_rng_state = torch.get_rng_state()
@@ -179,17 +179,17 @@ class _LeanFFN(torch.autograd.Function):
         grad_h2o_bias = grad_output_2d.sum(0)

         return (
-            grad_input,
-            grad_ln_weight,
-            grad_ln_bias,
-            grad_i2h_weight,
-            grad_i2h_bias,
-            grad_h2o_weight,
-            grad_h2o_bias,
-            None,
-            None,
-            None,
-            None,
+            grad_input,
+            grad_ln_weight,
+            grad_ln_bias,
+            grad_i2h_weight,
+            grad_i2h_bias,
+            grad_h2o_weight,
+            grad_h2o_bias,
+            None,
+            None,
+            None,
+            None,
         )

@@ -212,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
         self.register_buffer("cos", cos)
         self.register_buffer("sin", sin)

-        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])
+        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])


 @torch.no_grad()
@@ -243,13 +243,13 @@ def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tenso

 class LeanSelfAttention(nn.Module):
     def __init__(
-        self,
-        hidden_size: int,
-        num_attention_heads: int,
-        max_positions: int,
-        attention_core: Optional[nn.Module] = None,
-        layer_norm_eps: float = 1e-12,
-        **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        attention_core: Optional[nn.Module] = None,
+        layer_norm_eps: float = 1e-12,
+        **kwargs,
     ):
         """Attention layer that does not hog GPU memory"""
         super().__init__()
@@ -311,7 +311,7 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = attention_scores / math.sqrt(query.shape[-1])

         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
         attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))

         if attention_mask is not None:
@@ -330,12 +330,12 @@ class RotaryAttentionCore(SimpleAttentionCore):
     """Attention core that applies rotary embeddings to queries and keys before computing dot products"""

     def __init__(
-        self,
-        hidden_size: int,
-        num_attention_heads: int,
-        max_positions: int,
-        rotary_emb: Optional[RotaryEmbeddings] = None,
-        **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        rotary_emb: Optional[RotaryEmbeddings] = None,
+        **kwargs,
     ):
         super().__init__(hidden_size, num_attention_heads, max_positions, **kwargs)
         if rotary_emb is None:
@@ -393,7 +393,7 @@ class LeanAlbertEmbeddings(nn.Module):

     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -413,7 +413,7 @@ class LeanAlbertEmbeddings(nn.Module):

         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings

@@ -458,8 +458,7 @@ class LeanAlbertLayerGroup(AlbertLayerGroup):
         self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])

     def forward(
-        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False,
-        output_hidden_states=False
+        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
     ):
         if head_mask is not None and any(head_mask):
             raise NotImplementedError(f"head mask was provided, but it is not supported")
@@ -496,13 +495,13 @@ class LeanAlbertTransformer(AlbertTransformer):
         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)

     def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
     ):
         # TODO this should entire be replaced with inheritance and post_layer_norm
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
@@ -585,7 +584,7 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
 from hivemind.moe.server.layers.custom_experts import register_expert_class

 head_sample_input = lambda batch_size, hid_dim: (
-    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
 )

@@ -602,7 +601,7 @@ class HeadExpert(nn.Module):

     def forward(self, input_ids):
         embedding_output = self.embeddings(input_ids)
-        encoder_outputs, = self.encoder(embedding_output, return_dict=False)
+        (encoder_outputs,) = self.encoder(embedding_output, return_dict=False)

         return encoder_outputs

@@ -644,9 +643,8 @@ class BodyExpert(nn.Module):


 tail_sample_input = lambda batch_size, hid_dim: (
-
-    torch.empty((batch_size, 512, hid_dim)),
-    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
+    torch.empty((batch_size, 512, hid_dim)),
+    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
 )
