block.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. """
  2. Bloom intermediate layer
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. import math
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.quantized.dynamic.modules.linear
  10. from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
  11. pre_process_alibi_for_pad, split_tensor_along_last_dim)
  12. class BloomAttention(nn.Module):
  13. def __init__(self, config, layer_number=None):
  14. super().__init__()
  15. self.hidden_size = config.hidden_size
  16. self.num_heads = config.n_head
  17. self.head_dim = self.hidden_size // self.num_heads
  18. self.split_size = self.hidden_size
  19. self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
  20. self.masked_softmax_fusion = config.masked_softmax_fusion
  21. self.hidden_dropout = config.hidden_dropout
  22. if self.head_dim * self.num_heads != self.hidden_size:
  23. raise ValueError(
  24. f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
  25. f" {self.num_heads})."
  26. )
  27. # Layer-wise attention scaling
  28. self.layer_number = max(1, layer_number)
  29. self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
  30. # Scaled Softmax
  31. self.scale_mask_softmax = BloomScaledSoftmax(
  32. self.masked_softmax_fusion,
  33. attention_mask_func,
  34. self.attention_softmax_in_fp32,
  35. self.layer_number,
  36. )
  37. if config.compression == "qint8":
  38. self.query_key_value = nn.quantized.dynamic.modules.Linear(
  39. self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
  40. )
  41. self.dense = nn.quantized.dynamic.modules.Linear(
  42. self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
  43. )
  44. else:
  45. self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
  46. self.dense = nn.Linear(self.hidden_size, self.hidden_size)
  47. self.attention_dropout = nn.Dropout(config.attention_dropout)
  48. def forward(
  49. self,
  50. hidden_states,
  51. residual,
  52. layer_past=None,
  53. attention_mask=None,
  54. alibi=None,
  55. head_mask=None,
  56. use_cache=False,
  57. output_attentions=False,
  58. ):
  59. if alibi is None:
  60. current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
  61. alibi = build_alibi_tensor(
  62. current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
  63. )
  64. # hidden_states: [batch_size, seq_length, hidden_size]
  65. # apply preprocessing if the input is padded
  66. if attention_mask is not None:
  67. alibi = pre_process_alibi_for_pad(alibi, attention_mask)
  68. # otherwise repeat alibi tensor with the batch size
  69. else:
  70. alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
  71. mixed_x_layer = self.query_key_value(hidden_states)
  72. # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
  73. new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
  74. mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
  75. # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
  76. (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
  77. if layer_past is not None:
  78. past_key, past_value = layer_past
  79. key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
  80. value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
  81. if use_cache is True:
  82. present = (key_layer, value_layer)
  83. else:
  84. present = None
  85. # [batch_size, head_dim, q_length, k_length]
  86. output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
  87. # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
  88. query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
  89. # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
  90. key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
  91. # Raw attention scores. [batch_size * num_heads, q_length, k_length]
  92. beta = 1.0 / self.layer_number
  93. matmul_result = torch.baddbmm(
  94. alibi,
  95. query_layer.transpose(1, 0),
  96. key_layer.transpose(1, 0).transpose(1, 2),
  97. beta=beta,
  98. alpha=(1.0 / self.norm_factor),
  99. )
  100. # change view to [batch_size, num_heads, q_length, k_length]
  101. attention_scores = matmul_result.view(*output_size)
  102. # attention scores and attention mask [b, np, sq, sk]
  103. max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
  104. attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
  105. attention_probs = self.attention_dropout(attention_probs)
  106. if head_mask is not None:
  107. attention_probs = attention_probs * head_mask
  108. # context layer shape: [batch_size, num_heads, q_length, head_dim]
  109. output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
  110. # change view [k_length, batch_size x num_heads, head_dim]
  111. value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
  112. # change view [batch_size x num_heads, q_length, k_length]
  113. attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
  114. # matmul: [batch_size * num_heads, q_length, head_dim]
  115. context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
  116. # change view [batch_size, num_heads, q_length, head_dim]
  117. context_layer = context_layer.view(*output_size)
  118. # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
  119. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  120. # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
  121. new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
  122. context_layer = context_layer.view(*new_context_layer_shape)
  123. # Output. [q_length, batch_size, hidden_size]
  124. # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
  125. output_tensor = self.dense(context_layer)
  126. output = output_tensor.transpose(1, 0)
  127. output = dropout_add(output, residual, self.hidden_dropout, self.training)
  128. outputs = (output, present)
  129. if output_attentions:
  130. outputs += (attention_probs,)
  131. return outputs
  132. class BloomMLP(nn.Module):
  133. def __init__(self, config):
  134. super().__init__()
  135. self.hidden_size = config.hidden_size
  136. if config.compression == "qint8":
  137. self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
  138. self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
  139. )
  140. self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
  141. 4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
  142. )
  143. else:
  144. self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
  145. self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
  146. self.hidden_dropout = config.hidden_dropout
  147. self.gelu_impl = BloomGelu()
  148. def forward(self, hidden_states, residual):
  149. hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
  150. intermediate_output = self.dense_4h_to_h(hidden_states)
  151. output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
  152. return output
  153. class BloomBlock(nn.Module):
  154. def __init__(self, config, layer_number=None):
  155. super().__init__()
  156. self.hidden_size = config.hidden_size
  157. self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
  158. self.n_head = config.n_head
  159. self.self_attention = BloomAttention(config, layer_number=layer_number)
  160. self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
  161. self.mlp = BloomMLP(config)
  162. self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
  163. self.hidden_dropout = config.hidden_dropout
  164. def forward(
  165. self,
  166. hidden_states,
  167. layer_past=None,
  168. attention_mask=None,
  169. head_mask=None,
  170. use_cache=False,
  171. output_attentions=False,
  172. alibi=None,
  173. ):
  174. # hidden_states: [batch_size, seq_length, hidden_size]
  175. # Layer norm at the beginning of the transformer layer.
  176. layernorm_output = self.input_layernorm(hidden_states)
  177. # Layer norm post the self attention.
  178. if self.apply_residual_connection_post_layernorm:
  179. residual = layernorm_output
  180. else:
  181. residual = hidden_states
  182. # Self attention.
  183. attn_outputs = self.self_attention(
  184. layernorm_output,
  185. residual,
  186. layer_past=layer_past,
  187. attention_mask=attention_mask,
  188. alibi=alibi,
  189. head_mask=head_mask,
  190. use_cache=use_cache,
  191. output_attentions=output_attentions,
  192. )
  193. attention_output = attn_outputs[0]
  194. outputs = attn_outputs[1:]
  195. layernorm_output = self.post_attention_layernorm(attention_output)
  196. # Get residual
  197. if self.apply_residual_connection_post_layernorm:
  198. residual = layernorm_output
  199. else:
  200. residual = attention_output
  201. # MLP.
  202. output = self.mlp(layernorm_output, residual)
  203. if use_cache:
  204. outputs = (output,) + outputs
  205. else:
  206. outputs = (output,) + outputs[1:]
  207. return outputs # hidden_states, present, attentions