  1. """
  2. Utility operations used in the the BLOOM model
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. import math
  7. import torch
  8. import torch.autograd
  9. from torch import nn


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.

    Args:
        tensor ([`torch.tensor`], *required*):
            input tensor to split
        num_partitions ([`int`], *required*):
            number of partitions to split the tensor
        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
            If True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    numerator, denominator = tensor.size()[last_dim], num_partitions
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    last_dim_size = numerator // denominator
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
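
# Example (illustrative sketch, not from the original file; shapes are hypothetical):
# splitting a fused projection of hidden size 12 into 3 equal chunks along the last dim.
#
#     fused = torch.randn(2, 5, 12)                   # (batch, seq, 3 * chunk)
#     q, k, v = split_tensor_along_last_dim(fused, 3)
#     q.shape                                         # torch.Size([2, 5, 4])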


def attention_mask_func(attention_scores, attention_mask, causal_mask):
    if attention_mask.dtype == torch.bool:
        attention_mask_bool = ~attention_mask
    else:
        attention_mask_bool = (1 - attention_mask).bool()

    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
    padded_causal_mask = (
        attention_mask_bool[:, None, key_length - query_length : key_length, None]
        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
    ).bool()
    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
    # Make use of floats
    return (
        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
        padded_causal_mask,
    )
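
# Shape sketch (an assumption inferred from the indexing above, not documented in the
# original): `attention_scores` is (batch, n_heads, q_len, k_len), `attention_mask` is
# (batch, seq_len) and `causal_mask` is (1, 1, max_positions, max_positions). Masked
# positions are filled in place with -10000.0 before the softmax is applied.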


def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
    """
    Link to paper: https://arxiv.org/abs/2108.12409 The alibi tensor is not causal as the original paper mentions; it
    relies on a translation invariance of softmax for a quick implementation: with `l` a tensor and `a` a fixed value,
    `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        max_seq_len (`int`, *required*):
            max sequence length
        n_head (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor

    Returns:
        tensor shaped (n_head, 1, max_seq_len)
    """

    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
    alibi = alibi.to(dtype)
    return alibi
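
# Example (illustrative sketch, not from the original file): with 4 heads the slopes
# are [1/4, 1/16, 1/64, 1/256], so each head gets a linear positional bias.
#
#     alibi = build_alibi_tensor(8, 4, dtype=torch.float32)
#     alibi.shape       # torch.Size([4, 1, 8])  i.e. (n_head, 1, max_seq_len)
#     alibi[0, 0]       # 0.25 * torch.arange(8)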


def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
    """
    Pre-process the alibi tensor for padding.

    Args:
        alibi ([`torch.tensor`], *required*):
            alibi tensor to pre-process
        attention_mask ([`torch.tensor`], *required*):
            attention mask to pre-process
        num_heads (`int`, *required*):
            number of attention heads
    """
    # Sanity check that we are not inferring fewer tokens than the total sequence length
    # This usually happens when the inference is done with past_key_values
    # In this case we re-create the alibi tensor with the correct sequence length
    if attention_mask.shape[-1] != alibi.shape[-1]:
        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
            attention_mask.shape[0], 1, 1
        )
    # Get the indexes of the padding tokens
    index_x0, index_y0 = torch.where(attention_mask == 0.0)
    index_x1, index_y1 = torch.where(attention_mask == 1.0)

    # Clone the embeddings - we can detach because the embeddings are not learned
    # Get a reference tensor
    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)

    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
    # Only where you do not have padding. Replace padding tokens by zeros
    # This operation can be seen as a shifting operation.
    for i, index in enumerate(torch.unique(index_x0)):
        slice_to_modify = torch.zeros_like(slice_reference_alibi)
        index_shift = index_y1[index_x1 == index]
        shift_value = len(index_shift)
        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
    return alibi
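
# Shape sketch (an assumption inferred from the indexing above, not documented in the
# original): `alibi` is expected to be (batch_size * num_heads, 1, seq_len) and
# `attention_mask` (batch_size, seq_len). For each row that contains padding, the
# reference alibi values are shifted onto the non-padded positions and the padded
# positions are set to zero.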


def dropout_add(x, residual, prob, training):
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out
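
# Example (illustrative sketch, not from the original file): in eval mode the dropout
# is a no-op, so this reduces to a plain residual addition.
#
#     hidden = torch.randn(2, 4)
#     residual = torch.randn(2, 4)
#     out = dropout_add(hidden, residual, prob=0.1, training=False)
#     torch.equal(out, residual + hidden)   # True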


def bloom_gelu_forward(x):
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
    make the model jittable.

    Args:
        x (`torch.tensor`, *required*):
            input hidden states
    """
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
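
# Note (added, hedged): 0.79788456 approximates sqrt(2 / pi), so this is the standard
# tanh approximation of GELU; on PyTorch versions that support the `approximate` flag
# it should closely match nn.functional.gelu(x, approximate="tanh").
#
#     x = torch.randn(4)
#     torch.allclose(bloom_gelu_forward(x), nn.functional.gelu(x, approximate="tanh"), atol=1e-6)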


def bloom_gelu_back(g, x):
    """
    Gradient of the tanh approximation of GELU. The gradient of the exact GELU is:
    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.tensor`, *required*):
            gradient output tensor
        x (`torch.tensor`, *required*):
            input tensor
    """
    x = x[0]  # x is a tuple of 1 element, needs to unpack it first
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors
        tmp = bloom_gelu_back(grad_output, input)
        return tmp


class BloomGelu(nn.Module):
    """
    BloomBiasGelu wrapper that uses the simple forward function in inference mode to keep the model torchscriptable,
    and the autograd function in training mode to get accurate gradients. Partly copied from Megatron-DeepSpeed code
    and adapted for our needs.

    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if self.training:
            return GeLUFunction.apply(x)
        else:
            return bloom_gelu_forward(x)
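
# Example (illustrative sketch, not from the original file): the module picks its code
# path from `self.training`, so `.train()` / `.eval()` switch between the autograd
# function and the plain (torchscriptable) implementation.
#
#     gelu = BloomGelu()
#     gelu.eval()
#     y = gelu(torch.randn(2, 3))    # uses bloom_gelu_forward
#     gelu.train()
#     y = gelu(torch.randn(2, 3))    # uses GeLUFunction.apply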


class BloomScaledSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Args:
        scaled_masked_softmax_fusion (`bool`, *required*):
            flag to indicate whether the user wants to use softmax fusion
        mask_func (`function`, *required*):
            mask function to be applied.
        softmax_in_fp32 (`bool`, *required*):
            if True, softmax is performed at fp32 precision.
        scale (`float`, *required*):
            scaling factor used in input tensor scaling.
    """

    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
        super().__init__()
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        if not (self.scale is None or softmax_in_fp32):
            raise ValueError("softmax should be in fp32 when scaled")

    def forward(self, input, mask, max_positions):
        input_dtype = input.dtype
        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype

        if self.scale is not None:
            input = input * self.scale

        if mask is not None:
            mask = mask.to(input.device)
            causal_mask = (
                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
                .view(1, 1, max_positions, max_positions)
                .to(input.device)
            )
            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
        else:
            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)

        if input_in_16bit and self.softmax_in_fp32:
            probs = probs.to(dtype=input_dtype)

        return probs
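
# Example (illustrative sketch, not from the original file; shapes and values are
# hypothetical): applying the masked softmax to raw attention scores of shape
# (batch, n_heads, q_len, k_len) with a (batch, k_len) attention mask.
#
#     softmax = BloomScaledSoftmax(
#         scaled_masked_softmax_fusion=False,
#         mask_func=attention_mask_func,
#         softmax_in_fp32=True,
#         scale=None,
#     )
#     scores = torch.randn(1, 4, 8, 8)
#     attention_mask = torch.ones(1, 8)
#     probs = softmax(scores, attention_mask, max_positions=8)   # (1, 4, 8, 8)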