@@ -55,26 +55,24 @@ class LeanAlbertConfig(AlbertConfig):
 class LeanFFN(nn.Module):
     """
     A transformer FFN module that doesn't hog your GPU memory.
-    Complete with pre-LayerNorm, residual connections and dropout.
+    Complete with pre-LayerNorm and residual connections.
     :param gated: use gated activations based on https://arxiv.org/abs/2002.05202 and https://arxiv.org/abs/2102.11972
     note: gated activations require 1.5x more parameters compared to their non-gated variants.
     """

     def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        activation=F.gelu,
-        gated: bool = False,
-        layer_norm_eps: float = 1e-12,
-        dropout: float = 0.0,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        activation=F.gelu,
+        gated: bool = False,
+        layer_norm_eps: float = 1e-12,
     ):
         super().__init__()
         self.dense_i2h = nn.Linear(hidden_size, intermediate_size * 2 if gated else intermediate_size)
         self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
         self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
         self.activation = activation
-        self.dropout = dropout

     def forward(self, input):
         return _LeanFFN.apply(
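For context on the gated option in the docstring above: with gated=True, dense_i2h produces twice the intermediate width, which is then split into an activated branch and a multiplicative gate as in the GLU papers linked there. The sketch below illustrates that convention with made-up names; the real split lives in _LeanFFN._apply_activation, which is outside this diff, so treat the exact pairing as an assumption.

import torch
from torch import nn
import torch.nn.functional as F


def gated_ffn_reference(x, dense_i2h, dense_h2o, activation=F.gelu):
    # assumed split convention: one half goes through the activation,
    # the other half acts as a multiplicative gate (GLU-style)
    value, gate = dense_i2h(x).chunk(2, dim=-1)
    return dense_h2o(activation(value) * gate)


hidden, intermediate = 8, 32
x = torch.randn(2, 5, hidden)
out = gated_ffn_reference(x, nn.Linear(hidden, intermediate * 2), nn.Linear(intermediate, hidden))
assert out.shape == x.shape

With hidden size h and intermediate size i this costs roughly 2*h*i (i2h) plus i*h (h2o) weights instead of 2*h*i total, which is where the docstring's 1.5x figure comes from.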
@@ -86,7 +84,6 @@ class LeanFFN(nn.Module):
             self.dense_h2o.weight,
             self.dense_h2o.bias,
             self.activation,
-            self.dropout,
             self.training,
             self.layer_norm.eps,
         )
@@ -106,20 +103,19 @@ class _LeanFFN(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx,
-        input,
-        ln_weight,
-        ln_bias,
-        i2h_weight,
-        i2h_bias,
-        h2o_weight,
-        h2o_bias,
-        activation,
-        dropout,
-        training,
-        ln_eps,
+        ctx,
+        input,
+        ln_weight,
+        ln_bias,
+        i2h_weight,
+        i2h_bias,
+        h2o_weight,
+        h2o_bias,
+        activation,
+        training,
+        ln_eps,
     ):
-        ctx._activation, ctx._dropout, ctx._training, ctx._ln_eps = activation, dropout, training, ln_eps
+        ctx._activation, ctx._training, ctx._ln_eps = activation, training, ln_eps
         ctx._cpu_rng_state = torch.get_rng_state()
         ctx._device_rng_states = get_device_states(input)

@@ -131,7 +127,6 @@ class _LeanFFN(torch.autograd.Function):
         hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, h2o_weight.shape[1])

         out = F.linear(hid_act, h2o_weight, h2o_bias)
-        out = F.dropout(out, dropout, training, inplace=True)
         out = out.add_(input_2d)
         ctx.save_for_backward(input, pre_activation, ln_weight, ln_bias, i2h_weight, h2o_weight)
         return out.view(*input.shape)
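The "doesn't hog GPU memory" claim rests on _LeanFFN being a hand-written autograd function: forward saves only the inputs and pre_activation and recomputes everything else during backward. A minimal sketch of that recompute-in-backward pattern (not the actual _LeanFFN code, just the general technique it relies on):

import torch
import torch.nn.functional as F


class LinearGELU(torch.autograd.Function):
    """Sketch: compute gelu(linear(x)) without keeping the (batch, intermediate) buffer for backward."""

    @staticmethod
    def forward(ctx, x, weight, bias):
        ctx.save_for_backward(x, weight, bias)  # references only; no intermediate activation is stored
        return F.gelu(F.linear(x, weight, bias))

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, bias = ctx.saved_tensors
        with torch.enable_grad():
            x_, w_, b_ = (t.detach().requires_grad_(True) for t in (x, weight, bias))
            y = F.gelu(F.linear(x_, w_, b_))  # recompute the forward pass instead of caching it
            grad_x, grad_w, grad_b = torch.autograd.grad(y, (x_, w_, b_), grad_out)
        return grad_x, grad_w, grad_b

LinearGELU.apply(x, w, b) matches F.gelu(F.linear(x, w, b)) in values and gradients while keeping less memory alive between forward and backward.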
@@ -184,17 +179,17 @@ class _LeanFFN(torch.autograd.Function):
         grad_h2o_bias = grad_output_2d.sum(0)

         return (
-            grad_input,
-            grad_ln_weight,
-            grad_ln_bias,
-            grad_i2h_weight,
-            grad_i2h_bias,
-            grad_h2o_weight,
-            grad_h2o_bias,
-            None,
-            None,
-            None,
-            None,
+            grad_input,
+            grad_ln_weight,
+            grad_ln_bias,
+            grad_i2h_weight,
+            grad_i2h_bias,
+            grad_h2o_weight,
+            grad_h2o_bias,
+            None,
+            None,
+            None,
+            None,
         )


@@ -217,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
             self.register_buffer("cos", cos)
             self.register_buffer("sin", sin)

-        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])
+        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])


 @torch.no_grad()
@@ -248,18 +243,18 @@ def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tenso

 class LeanSelfAttention(nn.Module):
     def __init__(
-        self,
-        hidden_size: int,
-        num_attention_heads: int,
-        attention_core: Optional[nn.Module] = None,
-        hidden_dropout_prob: float = 0,
-        layer_norm_eps: float = 1e-12,
-        **kwargs,
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        attention_core: Optional[nn.Module] = None,
+        layer_norm_eps: float = 1e-12,
+        **kwargs,
     ):
         """Attention layer that does not hog GPU memory"""
         super().__init__()
         if attention_core is None:
-            attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, **kwargs)
+            attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, max_positions, **kwargs)
         else:
             assert len(kwargs) == 0, f"Unexpected parameters: {kwargs}"

@@ -268,7 +263,6 @@ class LeanSelfAttention(nn.Module):
         self.dense_qkv = nn.Linear(hidden_size, hidden_size * 3)
         self.dense_out = nn.Linear(hidden_size, hidden_size)
         self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.output_dropout = nn.Dropout(hidden_dropout_prob, inplace=False)

     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         hidden_states_ln = self.layer_norm(hidden_states)
@@ -276,19 +270,25 @@ class LeanSelfAttention(nn.Module):
         query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
         attention_output, attention_probs = checkpoint(self.attention_core, query, key, value, attention_mask)
         projected_context_layer = self.dense_out(attention_output)
-        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
-        layernormed_context_layer = projected_context_layer_dropout + hidden_states
+        layernormed_context_layer = projected_context_layer + hidden_states
         return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


 class SimpleAttentionCore(nn.Module):
-    def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0):
+    def __init__(self, hidden_size: int, num_attention_heads: int, max_positions):
         super().__init__()
         assert hidden_size % num_attention_heads == 0
-        self.attention_dropout = nn.Dropout(attention_probs_dropout, inplace=False)
         self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
         self.attention_head_size = hidden_size // num_attention_heads

+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4))
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
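The checkpoint(self.attention_core, query, key, value, attention_mask) call above is the other half of the memory story: the (batch, heads, seq, seq) attention probabilities are not retained for backward, the core is simply re-run. Which checkpoint implementation is imported here is not visible in this diff; the sketch below shows the same idea with torch.utils.checkpoint on a toy core.

import torch
from torch.utils.checkpoint import checkpoint


def toy_attention_core(query, key, value):
    scores = query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5
    probs = torch.softmax(scores, dim=-1)  # the large seq x seq tensor we avoid storing
    return probs @ value


q, k, v = (torch.randn(2, 16, 64, requires_grad=True) for _ in range(3))
out = checkpoint(toy_attention_core, q, k, v, use_reentrant=False)
out.sum().backward()  # probs are recomputed here rather than read back from memory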
@@ -310,6 +310,10 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = torch.matmul(query, key.transpose(-1, -2))
         attention_scores = attention_scores / math.sqrt(query.shape[-1])

+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
+        attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))
+
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_mask
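The new buffers mirror the GPT-2 causal mask trick: a lower-triangular (max_positions, max_positions) matrix registered once, a finite fill value of -1e4 that stays representable in fp16, and a slice [:, :, key_length - query_length : key_length, :key_length] that picks the mask rows for the current queries, which also covers the case where a few new queries attend to a longer set of cached keys. A small standalone check of what that slice selects (toy sizes):

import torch

max_positions = 5
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(1, 1, max_positions, max_positions)

# full sequence: 4 queries over 4 keys -> the usual lower-triangular mask
q_len = k_len = 4
print(bias[:, :, k_len - q_len : k_len, :k_len].bool()[0, 0])

# incremental decoding: 1 new query over 4 keys -> a single all-True row
q_len, k_len = 1, 4
print(bias[:, :, k_len - q_len : k_len, :k_len].bool()[0, 0])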
@@ -317,10 +321,6 @@ class SimpleAttentionCore(nn.Module):
         # Normalize the attention scores to probabilities.
         attention_probs = torch.softmax(attention_scores, dim=-1)

-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.attention_dropout(attention_probs)
-
         attention_output = torch.matmul(attention_probs, value)
         attention_output = attention_output.transpose(2, 1).flatten(2)
         return attention_output, attention_probs
@@ -330,9 +330,14 @@ class RotaryAttentionCore(SimpleAttentionCore):
     """Attention core that applies rotary embeddings to queries and keys before computing dot products"""

     def __init__(
-        self, hidden_size: int, num_attention_heads: int, rotary_emb: Optional[RotaryEmbeddings] = None, **kwargs
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        max_positions: int,
+        rotary_emb: Optional[RotaryEmbeddings] = None,
+        **kwargs,
     ):
-        super().__init__(hidden_size, num_attention_heads, **kwargs)
+        super().__init__(hidden_size, num_attention_heads, max_positions, **kwargs)
         if rotary_emb is None:
             rotary_emb = RotaryEmbeddings(self.attention_head_size)
         self.rotary_emb = rotary_emb
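The rotate(x, cos, sin) helper used by RotaryEmbeddings is not part of this diff; for reference, the standard rotary-embedding rotation (RoFormer) looks roughly like the sketch below, with cos and sin broadcast against x exactly as in the [None, offset: seq_len + offset, None, :] slicing earlier in the file. The half-split pairing here is an assumption; the repository's rotate may pair dimensions differently.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)          # split the head dimension into two halves
    return torch.cat((-x2, x1), dim=-1)  # and rotate one into the other


def rotate_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq, heads, head_dim]; cos, sin: [1, seq, 1, head_dim]
    return x * cos + rotate_half(x) * sin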
@@ -362,10 +367,7 @@ def get_attention_core(config: LeanAlbertConfig):
     elif config.position_embedding_type == "rotary":
         rotary_emb = RotaryEmbeddings(config.hidden_size // config.num_attention_heads, config.rotary_embedding_base)
         return RotaryAttentionCore(
-            config.hidden_size,
-            config.num_attention_heads,
-            rotary_emb,
-            attention_probs_dropout=config.attention_probs_dropout_prob,
+            config.hidden_size, config.num_attention_heads, config.max_position_embeddings, rotary_emb
         )
     else:
         raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}")
@@ -383,7 +385,6 @@ class LeanAlbertEmbeddings(nn.Module):
         self.position_embeddings = get_input_embedding(config)

         self.layernorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)

         if self.position_embeddings is not None:
             # position_ids (1, len position emb) is contiguous in memory and exported when serialized
@@ -392,7 +393,7 @@ class LeanAlbertEmbeddings(nn.Module):

     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -412,12 +413,11 @@ class LeanAlbertEmbeddings(nn.Module):

         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings

         embeddings = self.layernorm(embeddings)
-        embeddings = self.dropout(embeddings)
         return embeddings


@@ -433,8 +433,8 @@ class LeanAlbertLayer(nn.Module):
         self.attention = LeanSelfAttention(
             config.hidden_size,
             config.num_attention_heads,
+            config.max_position_embeddings,
             attention_core=get_attention_core(config),
-            hidden_dropout_prob=config.hidden_dropout_prob,
             layer_norm_eps=config.layer_norm_eps,
         )

@@ -444,7 +444,6 @@ class LeanAlbertLayer(nn.Module):
             activation=ACT2FN[config.hidden_act],
             gated=config.hidden_act_gated,
             layer_norm_eps=config.layer_norm_eps,
-            dropout=config.hidden_dropout_prob,
         )

     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
@@ -459,9 +458,10 @@ class LeanAlbertLayerGroup(AlbertLayerGroup):
         self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])

     def forward(
-        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
+        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False,
+        output_hidden_states=False
     ):
-        if any(head_mask):
+        if head_mask is not None and any(head_mask):
             raise NotImplementedError(f"head mask was provided, but it is not supported")

         layer_hidden_states = ()
@@ -496,13 +496,13 @@ class LeanAlbertTransformer(AlbertTransformer):
         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)

     def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
     ):
         # TODO this should entire be replaced with inheritance and post_layer_norm
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
@@ -520,7 +520,7 @@ class LeanAlbertTransformer(AlbertTransformer):
             layer_group_output = self.albert_layer_groups[group_idx](
                 hidden_states,
                 attention_mask,
-                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
+                None,
                 output_attentions,
                 output_hidden_states,
             )
@@ -532,11 +532,13 @@ class LeanAlbertTransformer(AlbertTransformer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

+        hidden_states = self.post_layer_norm(hidden_states)
+
         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

         return BaseModelOutput(
-            last_hidden_state=self.post_layer_norm(hidden_states),
+            last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
             attentions=all_attentions,
         )
@@ -582,24 +584,109 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):

 from hivemind.moe.server.layers.custom_experts import register_expert_class

-albert_sample_input = lambda batch_size, hid_dim: (
-    torch.empty((batch_size, 512, hid_dim)),
-    torch.ones((batch_size, 512)),
+head_sample_input = lambda batch_size, hid_dim: (
+    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+)
+
+
+@register_expert_class("lm_head", head_sample_input)
+class HeadExpert(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
+        config.hidden_size = hid_dim
+        config.num_hidden_layers = 12
+
+        self.encoder = LeanAlbertTransformer(config)
+        self.embeddings = LeanAlbertEmbeddings(config)
+
+    def forward(self, input_ids):
+        embedding_output = self.embeddings(input_ids)
+        encoder_outputs, = self.encoder(embedding_output, return_dict=False)
+
+        return encoder_outputs
+
+
+body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, 512, hid_dim)),)
+
+
+@register_expert_class("lm_body", body_sample_input)
+class BodyExpert(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
+        config.hidden_size = hid_dim
+        config.num_hidden_layers = 12
+
+        self.config = config
+        self.albert_layer_groups = nn.ModuleList(
+            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
+        )
+        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        for i in range(self.config.num_hidden_layers):
+            # Index of the hidden group
+            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+
+            layer_group_output = self.albert_layer_groups[group_idx](
+                hidden_states,
+                None,
+                None,
+                False,
+                False,
+            )
+            hidden_states = layer_group_output[0]
+
+        hidden_states = self.post_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+tail_sample_input = lambda batch_size, hid_dim: (
+
+    torch.empty((batch_size, 512, hid_dim)),
+    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
 )


-@register_expert_class("albert", albert_sample_input)
-class LeanAlbertExpert(nn.Module):
+@register_expert_class("lm_tail", tail_sample_input)
+class TailExpert(nn.Module):
     def __init__(self, hid_dim):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.num_hidden_layers = 12
+
+        self.config = config
+        self.albert_layer_groups = nn.ModuleList(
+            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
+        )
+        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
+
+        self.lm_head = AlbertMLMHead(config)
+
+    def forward(self, hidden_states, labels):
+        for i in range(self.config.num_hidden_layers):
+            # Index of the hidden group
+            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+
+            layer_group_output = self.albert_layer_groups[group_idx](
+                hidden_states,
+                None,
+                None,
+                False,
+                False,
+            )
+            hidden_states = layer_group_output[0]

-        self.layer = LeanAlbertLayer(config)
+        hidden_states = self.post_layer_norm(hidden_states)

-    def forward(self, x, attention_mask):
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        extended_attention_mask = extended_attention_mask.to(dtype=x.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        lm_logits = self.lm_head(hidden_states)

-        return self.layer(x, attention_mask=extended_attention_mask)[0]
+        # Shift so that tokens < n predict n
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
+        return loss
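TailExpert's loss is the standard causal-LM shift: the logits at position t are scored against the token at t + 1, permute(0, 2, 1) moves the vocabulary dimension to the class axis that F.cross_entropy expects for sequence inputs, and reduction="none" leaves a per-token loss for the caller to mask or average. A toy shape check (illustrative sizes only):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 6, 11
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = lm_logits[..., :-1, :].contiguous()  # predictions for positions 0 .. seq_len - 2
shift_labels = labels[..., 1:].contiguous()         # targets are the next tokens
loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
print(loss.shape)  # torch.Size([2, 5]), one value per predicted token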