  1. """
  2. Bloom intermediate layer
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. import math
  7. import torch
  8. import torch.nn.quantized.dynamic.modules.linear
  9. import torch.nn as nn
  10. from src.ops import BloomScaledSoftmax, BloomGelu
  11. from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add


class BloomAttention(nn.Module):
    def __init__(self, config, layer_number=None):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.masked_softmax_fusion = config.masked_softmax_fusion
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Layer-wise attention scaling
        self.layer_number = max(1, layer_number)
        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number

        # Scaled Softmax
        self.scale_mask_softmax = BloomScaledSoftmax(
            self.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            self.layer_number,
        )

        if config.compression == 'qint8':
            self.query_key_value = nn.quantized.dynamic.modules.Linear(
                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
            self.dense = nn.quantized.dynamic.modules.Linear(
                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
        else:
            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
            self.dense = nn.Linear(self.hidden_size, self.hidden_size)

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states,
        residual,
        layer_past=None,
        attention_mask=None,
        alibi=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]
        # repeat alibi tensor with the batch size
        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)

        # apply preprocessing if the input is padded
        if attention_mask is not None and 0 in attention_mask:
            alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)

        mixed_x_layer = self.query_key_value(hidden_states)

        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        # [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))

        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)

        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)

        # slice alibi tensor up to the key length
        sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]

        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        beta = 1.0 / self.layer_number

        matmul_result = torch.baddbmm(
            sliced_alibi,
            query_layer.transpose(1, 0),
            key_layer.transpose(1, 0).transpose(1, 2),
            beta=beta,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
            value_layer.dtype
        )
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # context layer shape: [batch_size, num_heads, q_length, head_dim]
        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [k_length, batch_size x num_heads, head_dim]
        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)

        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output. [q_length, batch_size, hidden_size]
        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        output_tensor = self.dense(context_layer)
        output = output_tensor.transpose(1, 0)

        output = dropout_add(output, residual, self.hidden_dropout, self.training)

        outputs = (output, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs


class BloomMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        if config.compression == 'qint8':
            self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
            self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
        else:
            self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
            self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
        self.hidden_dropout = config.hidden_dropout
        self.gelu_impl = BloomGelu()

    def forward(self, hidden_states, residual):
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
        intermediate_output = self.dense_4h_to_h(hidden_states)
        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        return output


class BloomBlock(nn.Module):
    def __init__(self, config, layer_number=None):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.n_head = config.n_head
        self.self_attention = BloomAttention(config, layer_number=layer_number)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = BloomMLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
        alibi=None,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]
        outputs = attn_outputs[1:]

        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output, residual)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions
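

if __name__ == "__main__":
    # Illustrative shape check only -- not part of the Bloom block itself.
    # A minimal sketch of how BloomAttention.forward folds the ALiBi bias into the
    # query/key product with torch.baddbmm: beta scales the bias, alpha scales
    # Q @ K^T. All dimensions below are made up for the example; only torch is used.
    batch_size, num_heads, head_dim, q_length, k_length = 2, 4, 8, 5, 5
    layer_number = 3
    norm_factor = math.sqrt(head_dim) * layer_number

    # Query/key come in as [length, batch_size * num_heads, head_dim], as in forward() above.
    query = torch.randn(q_length, batch_size * num_heads, head_dim)
    key = torch.randn(k_length, batch_size * num_heads, head_dim)
    # ALiBi bias already repeated across the batch: [batch_size * num_heads, 1, k_length];
    # random values here stand in for the real slopes produced in src.ops.
    alibi_bias = torch.randn(batch_size * num_heads, 1, k_length)

    scores = torch.baddbmm(
        alibi_bias,
        query.transpose(1, 0),                  # [batch_size * num_heads, q_length, head_dim]
        key.transpose(1, 0).transpose(1, 2),    # [batch_size * num_heads, head_dim, k_length]
        beta=1.0 / layer_number,
        alpha=1.0 / norm_factor,
    )
    # Raw attention scores, as in forward(): [batch_size * num_heads, q_length, k_length]
    assert scores.shape == (batch_size * num_heads, q_length, k_length)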