# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT modules that do not hog your GPU memory"""
import math
from functools import lru_cache
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.utils.checkpoint import checkpoint, get_device_states, set_device_states

from transformers import AlbertConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.albert.modeling_albert import (
    ACT2FN,
    AlbertLayerGroup,
    AlbertMLMHead,
    AlbertTransformer,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LeanAlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


class LeanAlbertConfig(AlbertConfig):
    rotary_embedding_base: int = 10_000
    hidden_act_gated: bool = False

    def __hash__(self):
        return hash("\t".join(f"{k}={v}" for k, v in self.__dict__.items() if not k.startswith("_")))


class LeanFFN(nn.Module):
    """
    A transformer FFN module that doesn't hog your GPU memory.
    Complete with pre-LayerNorm and residual connections.

    :param gated: use gated activations based on https://arxiv.org/abs/2002.05202 and https://arxiv.org/abs/2102.11972
      note: gated activations require 1.5x more parameters compared to their non-gated variants.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        activation=F.gelu,
        gated: bool = False,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        self.dense_i2h = nn.Linear(hidden_size, intermediate_size * 2 if gated else intermediate_size)
        self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.activation = activation

    def forward(self, input):
        return _LeanFFN.apply(
            input,
            self.layer_norm.weight,
            self.layer_norm.bias,
            self.dense_i2h.weight,
            self.dense_i2h.bias,
            self.dense_h2o.weight,
            self.dense_h2o.bias,
            self.activation,
            self.training,
            self.layer_norm.eps,
        )


class _LeanFFN(torch.autograd.Function):
    @staticmethod
    def _apply_activation(pre_activation: torch.Tensor, activation: callable, hid_size: int):
        if pre_activation.shape[-1] == hid_size:
            return activation(pre_activation)
        elif pre_activation.shape[-1] == 2 * hid_size:
            pre_gate, lin = pre_activation.split(pre_activation.shape[-1] // 2, dim=-1)
            return activation(pre_gate).mul_(lin)
        else:
            raise RuntimeError("The output size of FFN layer must be either 1x or 2x the intermediate_size.")

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        ln_weight,
        ln_bias,
        i2h_weight,
        i2h_bias,
        h2o_weight,
        h2o_bias,
        activation,
        training,
        ln_eps,
    ):
        ctx._activation, ctx._training, ctx._ln_eps = activation, training, ln_eps
        ctx._cpu_rng_state = torch.get_rng_state()
        ctx._device_rng_states = get_device_states(input)

        input_2d = input.view(-1, input.shape[-1])
        input_ln = F.layer_norm(input_2d, input.shape[-1:], ln_weight, ln_bias, ln_eps)
        pre_activation = F.linear(input_ln, i2h_weight, i2h_bias)
        hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, h2o_weight.shape[1])
        out = F.linear(hid_act, h2o_weight, h2o_bias)
        out = out.add_(input_2d)

        ctx.save_for_backward(input, pre_activation, ln_weight, ln_bias, i2h_weight, h2o_weight)
        return out.view(*input.shape)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_input = grad_ln_weight = grad_ln_bias = None
        grad_i2h_weight = grad_i2h_bias = grad_h2o_weight = grad_h2o_bias = None
        input, pre_activation, ln_weight, ln_bias, i2h_weight, h2o_weight = ctx.saved_tensors

        torch.set_rng_state(ctx._cpu_rng_state)
        set_device_states(*ctx._device_rng_states)

        input_2d = input.view(-1, input.shape[-1])
        grad_output_2d = grad_output.view(-1, grad_output.shape[-1])
        grad_hid_act = torch.mm(grad_output_2d, h2o_weight)

        with torch.enable_grad():
            # rematerialize activation
            pre_activation.requires_grad_(True)
            hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, h2o_weight.shape[1])
            (grad_hid,) = torch.autograd.grad(hid_act, pre_activation, grad_hid_act)
            pre_activation.requires_grad_(False)

        grad_input_ln_2d = torch.mm(grad_hid, i2h_weight)

        with torch.enable_grad():
            # rematerialize input_ln
            input_2d.requires_grad_(True)
            input_ln_2d = F.layer_norm(input_2d, input.shape[-1:], ln_weight, ln_bias, ctx._ln_eps)
            if any(ctx.needs_input_grad[0:3]):
                grad_input_2d, grad_ln_weight, grad_ln_bias = torch.autograd.grad(
                    outputs=input_ln_2d, inputs=[input_2d, ln_weight, ln_bias], grad_outputs=grad_input_ln_2d
                )
            input_2d.requires_grad_(False)
            input_ln_2d = input_ln_2d.detach_()

        if ctx.needs_input_grad[0]:
            grad_input_2d = grad_input_2d.add_(grad_output_2d)
            grad_input = grad_input_2d.view(*grad_output.shape)
        if ctx.needs_input_grad[3]:
            grad_i2h_weight = grad_hid.t().mm(input_ln_2d)
        if ctx.needs_input_grad[4]:
            grad_i2h_bias = grad_hid.sum(0)
        if ctx.needs_input_grad[5]:
            grad_h2o_weight = grad_output_2d.t().mm(hid_act)
        if ctx.needs_input_grad[6]:
            grad_h2o_bias = grad_output_2d.sum(0)

        # one gradient (or None) per non-ctx argument of forward: 7 tensors + activation, training, ln_eps
        return (
            grad_input,
            grad_ln_weight,
            grad_ln_bias,
            grad_i2h_weight,
            grad_i2h_bias,
            grad_h2o_weight,
            grad_h2o_bias,
            None,
            None,
            None,
        )
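

# Illustrative usage sketch (not part of the original module; sizes are arbitrary assumptions).
# It checks that LeanFFN acts as a residual block (output shape == input shape) and that the
# memory-saving custom backward runs. Guarded so that importing this file stays side-effect free.
if __name__ == "__main__":
    _demo_ffn = LeanFFN(hidden_size=64, intermediate_size=256, gated=True)
    _demo_input = torch.randn(2, 8, 64, requires_grad=True)
    _demo_output = _demo_ffn(_demo_input)
    assert _demo_output.shape == _demo_input.shape
    _demo_output.sum().backward()  # exercises _LeanFFN.backward with rematerialized activations
    assert _demo_input.grad is not None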


class RotaryEmbeddings(nn.Module):
    """Applies rotary position embeddings to a tensor, uses caching to improve performance"""

    def __init__(self, dim: int, base: int = 10_000):
        super().__init__()
        self.dim, self.base = dim, base

    def forward(self, x: torch.Tensor, offset: int = 0):
        """
        :param x: tensor of shape [batch_size, seq_len, nhead, hid_size]
        :param offset: add this value to all position indices
        """
        seq_len = x.shape[1]
        cos, sin = getattr(self, "cos", None), getattr(self, "sin", None)
        if cos is None or seq_len + offset >= cos.shape[0] or x.dtype != cos.dtype or x.device != cos.device:
            cos, sin = get_auxiliary_tensors(seq_len + offset, self.dim, x.dtype, x.device, self.base)
            self.register_buffer("cos", cos)
            self.register_buffer("sin", sin)
        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])


@torch.no_grad()
@torch.jit.script
def get_auxiliary_tensors(seq_len: int, dim: int, dtype: torch.dtype, device: torch.device, base: int):
    """
    Compute auxiliary sine and cosine tensors for rotary position embedding

    :returns: a tuple of (cos, sin) tensors of shape [seq_len, hid_size]
    """
    _buf = torch.linspace(0, -1 + 2 / dim, dim // 2, dtype=torch.float32, device=device)
    inv_freq = torch.pow(base, _buf, out=_buf).repeat(2)
    time_ix = torch.arange(seq_len, dtype=inv_freq.dtype, device=device)

    freqs = time_ix[:, None] * inv_freq[None, :]
    cos = torch.cos(freqs)
    sin = torch.sin(freqs, out=freqs)
    return cos.to(dtype), sin.to(dtype)


@torch.jit.script
def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """rotate pairwise coordinate using precomputed cos & sin tensors"""
    dim = x.shape[-1]
    x_left, x_right = x.split(split_size=dim // 2, dim=x.ndim - 1)
    x_rotated = torch.cat([x_right.neg(), x_left], dim=x.ndim - 1)
    return x * cos + x_rotated * sin
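

# Illustrative sketch (not from the original file; shapes are assumptions): a rotary embedding is a
# per-position rotation of coordinate pairs, so it must preserve the norm of every query/key vector.
# Run this file directly to verify that property on random inputs.
if __name__ == "__main__":
    _demo_rotary = RotaryEmbeddings(dim=32)
    _demo_q = torch.randn(2, 16, 4, 32)  # [batch_size, seq_len, num_heads, head_dim]
    _demo_q_rot = _demo_rotary(_demo_q)
    assert _demo_q_rot.shape == _demo_q.shape
    assert torch.allclose(_demo_q_rot.norm(dim=-1), _demo_q.norm(dim=-1), atol=1e-4)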


class LeanSelfAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        max_positions: int,
        attention_core: Optional[nn.Module] = None,
        layer_norm_eps: float = 1e-12,
        **kwargs,
    ):
        """Attention layer that does not hog GPU memory"""
        super().__init__()
        if attention_core is None:
            attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, max_positions, **kwargs)
        else:
            assert len(kwargs) == 0, f"Unexpected parameters: {kwargs}"

        self.hidden_size = hidden_size
        self.attention_core = attention_core
        self.dense_qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.dense_out = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states_ln = self.layer_norm(hidden_states)
        qkv_output = self.dense_qkv(hidden_states_ln)
        query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
        # checkpoint recomputes the attention core in backward instead of storing its activations,
        # including the [batch, heads, seq, seq] attention probabilities
        attention_output, attention_probs = checkpoint(self.attention_core, query, key, value, attention_mask)
        projected_context_layer = self.dense_out(attention_output)
        layernormed_context_layer = projected_context_layer + hidden_states
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class SimpleAttentionCore(nn.Module):
    def __init__(self, hidden_size: int, num_attention_heads: int, max_positions: int):
        super().__init__()
        assert hidden_size % num_attention_heads == 0
        self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        # causal mask: position i may only attend to positions <= i
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, attention_mask):
        """
        :param query: [batch_size, query_seq_len, hidden_size]
        :param key: [batch_size, kv_seq_len, hidden_size]
        :param value: [batch_size, kv_seq_len, hidden_size]
        :param attention_mask: additive mask broadcastable to [batch_size, num_heads, query_seq_len, kv_seq_len]
        :return: (outputs, probs)
          - outputs shape: [batch_size, query_seq_len, hidden_size]
          - probs shape: [batch_size, num_heads, query_seq_len, kv_seq_len]
        """
        query, key, value = map(self.transpose_for_scores, (query, key, value))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(query.shape[-1])

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
        attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the model's forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = torch.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_probs, value)
        attention_output = attention_output.transpose(2, 1).flatten(2)
        return attention_output, attention_probs
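

# Illustrative sketch (sizes are arbitrary assumptions, not from the original file): the attention
# core maps flat [batch, seq_len, hidden_size] query/key/value projections to an output of the same
# shape plus per-head attention probabilities.
if __name__ == "__main__":
    _demo_core = SimpleAttentionCore(hidden_size=64, num_attention_heads=4, max_positions=128)
    _demo_q, _demo_k, _demo_v = (torch.randn(2, 16, 64) for _ in range(3))
    _demo_out, _demo_probs = _demo_core(_demo_q, _demo_k, _demo_v, attention_mask=None)
    assert _demo_out.shape == (2, 16, 64)
    assert _demo_probs.shape == (2, 4, 16, 16)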


class RotaryAttentionCore(SimpleAttentionCore):
    """Attention core that applies rotary embeddings to queries and keys before computing dot products"""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        max_positions: int,
        rotary_emb: Optional[RotaryEmbeddings] = None,
        **kwargs,
    ):
        super().__init__(hidden_size, num_attention_heads, max_positions, **kwargs)
        if rotary_emb is None:
            rotary_emb = RotaryEmbeddings(self.attention_head_size)
        self.rotary_emb = rotary_emb

    def rotate(self, tensor: torch.Tensor):
        """:param tensor: query or key, shape: [batch_size, query_seq_len, hidden_size]"""
        tensor_split_heads = tensor.view(*(tensor.shape[:-1] + (self.num_attention_heads, self.attention_head_size)))
        return self.rotary_emb(tensor_split_heads).view(*tensor.shape)

    def forward(self, query, key, value, attention_mask):
        return super().forward(self.rotate(query), self.rotate(key), value, attention_mask)
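

# Illustrative sketch (assumed sizes): LeanSelfAttention wraps a pluggable attention core and runs
# it under gradient checkpointing; here it is paired with a RotaryAttentionCore.
if __name__ == "__main__":
    _demo_attn = LeanSelfAttention(
        hidden_size=64,
        num_attention_heads=4,
        max_positions=128,
        attention_core=RotaryAttentionCore(64, 4, 128),
    )
    _demo_hidden = torch.randn(2, 16, 64)
    (_demo_attn_out,) = _demo_attn(_demo_hidden)
    assert _demo_attn_out.shape == _demo_hidden.shape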


def get_input_embedding(config: LeanAlbertConfig):
    if config.position_embedding_type == "absolute":
        return nn.Embedding(config.max_position_embeddings, config.embedding_size)
    elif config.position_embedding_type == "rotary":
        return None
    else:
        raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}")


@lru_cache()
def get_attention_core(config: LeanAlbertConfig):
    if config.position_embedding_type == "absolute":
        return None
    elif config.position_embedding_type == "rotary":
        rotary_emb = RotaryEmbeddings(config.hidden_size // config.num_attention_heads, config.rotary_embedding_base)
        return RotaryAttentionCore(
            config.hidden_size, config.num_attention_heads, config.max_position_embeddings, rotary_emb
        )
    else:
        raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}")
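

# Illustrative sketch (assumed hyperparameters): LeanAlbertConfig.__hash__ makes configs usable as
# lru_cache keys, so repeated calls with the same config object reuse one attention core (and thus
# one rotary embedding cache) across all layers built from that config.
if __name__ == "__main__":
    _demo_cfg = LeanAlbertConfig(
        hidden_size=64, num_attention_heads=4, intermediate_size=256, position_embedding_type="rotary"
    )
    assert get_attention_core(_demo_cfg) is get_attention_core(_demo_cfg)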


class LeanAlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
        self.position_embeddings = get_input_embedding(config)
        self.layernorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)

        if self.position_embeddings is not None:
            # position_ids (1, len position emb) is contiguous in memory and exported when serialized
            self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if token_type_ids is None:
            # there is no position_ids buffer when rotary embeddings are used, so infer the device from the inputs
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embeddings is not None:
            if position_ids is None:
                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        embeddings = self.layernorm(embeddings)
        return embeddings


class LeanAlbertLayer(nn.Module):
    def __init__(self, config: LeanAlbertConfig):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = LeanSelfAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            attention_core=get_attention_core(config),
            layer_norm_eps=config.layer_norm_eps,
        )

        self.ffn = LeanFFN(
            config.hidden_size,
            config.intermediate_size,
            activation=ACT2FN[config.hidden_act],
            gated=config.hidden_act_gated,
            layer_norm_eps=config.layer_norm_eps,
        )

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attention_output, *extras = self.attention(hidden_states, attention_mask, output_attentions)
        ffn_output = self.ffn(attention_output)
        return (ffn_output, attention_output, *extras)


class LeanAlbertLayerGroup(AlbertLayerGroup):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        if head_mask is not None and any(head_mask):
            raise NotImplementedError("head mask was provided, but it is not supported")

        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, output_attentions)
            hidden_states = layer_output[0]
            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class LeanAlbertTransformer(AlbertTransformer):
    def __init__(self, config):
        nn.Module.__init__(self)

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList(
            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
        )
        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # TODO: this should be replaced entirely with inheritance plus post_layer_norm
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                None,
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.post_layer_norm(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
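

# End-to-end sketch with a small, made-up configuration (every hyperparameter below is an assumption
# chosen for the demo): embed a batch of token ids, run the lean transformer stack, and check the
# output shape. Token type ids are passed explicitly for clarity.
if __name__ == "__main__":
    _demo_config = LeanAlbertConfig(
        vocab_size=1000,
        embedding_size=32,
        hidden_size=64,
        intermediate_size=256,
        num_attention_heads=4,
        num_hidden_layers=4,
        position_embedding_type="rotary",
    )
    _demo_embeddings = LeanAlbertEmbeddings(_demo_config)
    _demo_transformer = LeanAlbertTransformer(_demo_config)
    _demo_ids = torch.randint(0, _demo_config.vocab_size, (2, 16))
    _demo_out = _demo_transformer(_demo_embeddings(_demo_ids, token_type_ids=torch.zeros_like(_demo_ids)))
    assert _demo_out.last_hidden_state.shape == (2, 16, _demo_config.hidden_size)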


from hivemind.moe.server.layers.custom_experts import register_expert_class

SEQUENCE_LENGTH = 2048

head_sample_input = lambda batch_size, hid_dim: (
    torch.randint(low=0, high=1000, size=(batch_size, SEQUENCE_LENGTH), dtype=torch.long),
)


@register_expert_class("lm_head", head_sample_input)
class HeadExpert(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
        config.hidden_size = hid_dim
        config.intermediate_size = 4 * config.hidden_size
        config.num_hidden_layers = 12
        config.vocab_size = 50304
        config.max_position_embeddings = SEQUENCE_LENGTH

        self.encoder = LeanAlbertTransformer(config)
        self.embeddings = LeanAlbertEmbeddings(config)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        (encoder_outputs,) = self.encoder(embedding_output, return_dict=False)
        return encoder_outputs


body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)


@register_expert_class("lm_body", body_sample_input)
class BodyExpert(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
        config.hidden_size = hid_dim
        config.intermediate_size = 4 * config.hidden_size
        config.num_hidden_layers = 12
        config.vocab_size = 50304
        config.max_position_embeddings = SEQUENCE_LENGTH

        self.config = config
        self.albert_layer_groups = nn.ModuleList(
            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
        )
        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)

    def forward(self, hidden_states):
        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                None,
                None,
                False,
                False,
            )
            hidden_states = layer_group_output[0]

        hidden_states = self.post_layer_norm(hidden_states)
        return hidden_states


tail_sample_input = lambda batch_size, hid_dim: (
    torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),
    torch.randint(0, 1000, (batch_size, SEQUENCE_LENGTH), dtype=torch.long),
)


@register_expert_class("lm_tail", tail_sample_input)
class TailExpert(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
        config.hidden_size = hid_dim
        config.intermediate_size = 4 * config.hidden_size
        config.num_hidden_layers = 12
        config.vocab_size = 50304
        config.max_position_embeddings = SEQUENCE_LENGTH

        self.config = config
        self.albert_layer_groups = nn.ModuleList(
            [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
        )
        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.lm_head = AlbertMLMHead(config)

    def forward(self, hidden_states, labels):
        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                None,
                None,
                False,
                False,
            )
            hidden_states = layer_group_output[0]

        hidden_states = self.post_layer_norm(hidden_states)
        lm_logits = self.lm_head(hidden_states)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Compute the per-token loss
        loss = F.cross_entropy(shift_logits.permute(0, 2, 1), shift_labels, reduction="none")
        return loss
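

# Illustrative sketch (not from the original file): the three experts form a pipeline, chained
# head -> body -> tail during training, with the tail returning a per-token LM loss of shape
# [batch_size, SEQUENCE_LENGTH - 1]. The check below only exercises the cheap *_sample_input
# factories that declare each expert's interface for hivemind's register_expert_class.
if __name__ == "__main__":
    (_demo_head_in,) = head_sample_input(2, 1024)
    (_demo_body_in,) = body_sample_input(2, 1024)
    _demo_tail_hidden, _demo_tail_labels = tail_sample_input(2, 1024)
    assert _demo_head_in.shape == _demo_tail_labels.shape == (2, SEQUENCE_LENGTH)
    assert _demo_body_in.shape == _demo_tail_hidden.shape == (2, SEQUENCE_LENGTH, 1024)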
  531. return loss