ops.py

  1. """
  2. Utility operations used in the the BLOOM model
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. import math
  7. import torch
  8. import torch.autograd
  9. import torch.nn.functional as F
  10. from torch import nn


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.

    Args:
        tensor ([`torch.Tensor`], *required*):
            input tensor to split
        num_partitions ([`int`], *required*):
            number of partitions to split the tensor into
        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
            If True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    numerator, denominator = tensor.size()[last_dim], num_partitions
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    last_dim_size = numerator // denominator
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
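

# Illustrative usage sketch (hypothetical helper, not part of the original module): splitting a
# fused projection of last-dimension size 12 into three query/key/value chunks of size 4 each.
def _example_split_tensor_along_last_dim():
    fused_qkv = torch.randn(2, 5, 12)  # [batch, seq_len, 3 * head_dim]
    query, key, value = split_tensor_along_last_dim(fused_qkv, 3, contiguous_split_chunks=True)
    assert query.shape == (2, 5, 4) and query.is_contiguous()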


def attention_mask_func(attention_scores, attention_mask, causal_mask):
    if attention_mask.dtype == torch.bool:
        attention_mask_bool = ~attention_mask
    else:
        attention_mask_bool = (1 - attention_mask).bool()

    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
    padded_causal_mask = (
        attention_mask_bool[:, None, key_length - query_length : key_length, None]
        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
    ).bool()
    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
    # Fill masked positions in place with a large negative value so they vanish after the softmax.
    return (
        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
        padded_causal_mask,
    )


def build_alibi_tensor(
    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """
    Build the ALiBi bias tensor. Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal as
    the original paper mentions; the implementation relies on the translation invariance of softmax: for a tensor `l`
    and a fixed value `a`, `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        max_seq_len (`int`, *required*):
            max sequence length
        n_head (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor
        device (`torch.device`, *optional*, default=`torch.device('cpu')`):
            device of the output alibi tensor

    Returns:
        A tensor of shape `(n_head, 1, max_seq_len)`.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != n_head:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
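

# Illustrative usage sketch (hypothetical helper, not part of the original module): each head
# gets one slope that multiplies the position index, so the bias at position 0 is zero for
# every head and the result has shape (n_head, 1, max_seq_len).
def _example_build_alibi_tensor():
    alibi = build_alibi_tensor(max_seq_len=6, n_head=4, dtype=torch.float32)
    assert alibi.shape == (4, 1, 6)
    assert torch.equal(alibi[:, 0, 0], torch.zeros(4))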


def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
    """
    Pre-process the alibi tensor for padding.

    Args:
        alibi ([`torch.Tensor`], *required*):
            alibi tensor to pre-process
        attention_mask ([`torch.Tensor`], *required*):
            attention mask to pre-process
    """
    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
    # ^-- [batch, max_len], values correspond to element indices after removing padding
    # We shift the alibi tensor and zero out every position where attention_mask == 0
    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
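

# Illustrative usage sketch (hypothetical helper, not part of the original module): re-indexing
# the ALiBi bias for a padded batch. The integer mask is assumed to be 1 for real tokens and 0
# for padding; the result is shaped [n_head * batch, 1, seq_len] with zeros at padded positions.
def _example_pre_process_alibi_for_pad():
    n_head, seq_len = 2, 4
    alibi = build_alibi_tensor(max_seq_len=seq_len, n_head=n_head, dtype=torch.float32)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # [batch=2, seq_len=4]
    padded_alibi = pre_process_alibi_for_pad(alibi, attention_mask)
    assert padded_alibi.shape == (n_head * attention_mask.shape[0], 1, seq_len)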


def dropout_add(x, residual, prob, training):
    """
    Dropout-add function.

    Args:
        x (`torch.Tensor`, *required*):
            input tensor
        residual (`torch.Tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out
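

# Illustrative usage sketch (hypothetical helper, not part of the original module): with
# training=False the dropout is the identity, so dropout_add reduces to a residual addition.
def _example_dropout_add():
    x, residual = torch.randn(2, 3), torch.randn(2, 3)
    out = dropout_add(x, residual, prob=0.1, training=False)
    assert torch.allclose(out, residual + x)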


def bloom_gelu_forward(x):
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference)
    to make the model jitable.

    Args:
        x (`torch.Tensor`, *required*):
            input hidden states
    """
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


def bloom_gelu_back(g, x):
    """
    Gradient of the tanh approximation of GELU. The gradient of the exact GELU is:
    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.Tensor`, *required*):
            gradient output tensor
        x (`torch.Tensor`, *required*):
            input tensor
    """
    x = x[0]  # x is a tuple of one element; unpack it first
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g
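

# Illustrative sanity-check sketch (hypothetical helper, not part of the original module):
# bloom_gelu_forward is the standard tanh approximation of GELU, so on PyTorch versions that
# support the `approximate` argument it should closely match F.gelu(x, approximate="tanh").
# bloom_gelu_back expects the saved-tensors tuple produced by GeLUFunction.forward.
def _example_bloom_gelu_functions():
    x = torch.linspace(-4.0, 4.0, steps=17)
    assert torch.allclose(bloom_gelu_forward(x), F.gelu(x, approximate="tanh"), atol=1e-6)
    grad = bloom_gelu_back(torch.ones_like(x), (x,))
    assert grad.shape == x.shape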


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors  # tuple of saved tensors; bloom_gelu_back unpacks it
        tmp = bloom_gelu_back(grad_output, input)
        return tmp


class BloomGelu(nn.Module):
    """
    BloomBiasGelu wrapper that uses the simple forward function in inference mode to keep the model torchscriptable,
    and the autograd function in training mode to get accurate gradients. Partly copied from Megatron-DeepSpeed code
    and adapted for our needs.

    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if self.training:
            return GeLUFunction.apply(x)
        else:
            return bloom_gelu_forward(x)
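

# Illustrative usage sketch (hypothetical helper, not part of the original module): the training
# path (custom autograd function) and the inference path (plain tanh approximation) produce the
# same forward values, and gradients flow through the custom autograd function.
def _example_bloom_gelu_module():
    gelu = BloomGelu()
    x = torch.randn(2, 3, requires_grad=True)
    y_train = gelu.train()(x)
    y_eval = gelu.eval()(x)
    assert torch.allclose(y_train, y_eval)
    y_train.sum().backward()
    assert x.grad is not None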


class BloomScaledSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax.

    Args:
        scaled_masked_softmax_fusion (`bool`, *required*):
            flag to indicate whether the user wants to use softmax fusion
        mask_func (`function`, *required*):
            mask function to be applied
        softmax_in_fp32 (`bool`, *required*):
            if True, softmax is performed at fp32 precision
        scale (`float`, *required*):
            scaling factor used in input tensor scaling
    """

    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
        super().__init__()
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        if not (self.scale is None or softmax_in_fp32):
            raise ValueError("softmax should be in fp32 when scaled")

    def forward(self, input, mask, max_positions):
        input_dtype = input.dtype
        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype

        if self.scale is not None:
            input = input * self.scale

        if mask is None:
            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)

        mask = mask.to(input.device)
        causal_mask = (
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
            .view(1, 1, max_positions, max_positions)
            .to(input.device)
        )
        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)

        if input_in_16bit and self.softmax_in_fp32:
            probs = probs.to(dtype=input_dtype)

        return probs
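

# Illustrative usage sketch (hypothetical helper, not part of the original module): causal-masked
# softmax over attention scores shaped [batch, n_heads, query_length, key_length], wired to the
# attention_mask_func defined above. Note that the mask function fills the scores in place.
def _example_bloom_scaled_softmax():
    batch, n_heads, seq_len = 2, 4, 5
    scores = torch.randn(batch, n_heads, seq_len, seq_len)
    attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)  # no padding
    softmax = BloomScaledSoftmax(
        scaled_masked_softmax_fusion=False, mask_func=attention_mask_func, softmax_in_fp32=True, scale=None
    )
    probs = softmax(scores, attention_mask, max_positions=seq_len)
    assert probs.shape == (batch, n_heads, seq_len, seq_len)
    # Each row sums to 1 over the unmasked (non-future) positions.
    assert torch.allclose(probs.sum(dim=-1), torch.ones(batch, n_heads, seq_len), atol=1e-5)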