# test_optimized_layers.py (9.1 KB)
  1. from typing import Optional, Tuple
  2. import pytest
  3. import torch
  4. from transformers.cache_utils import DynamicCache
  5. from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
  6. from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
  7. from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
  8. from petals.utils.auto_config import AutoDistributedConfig
  9. from petals.utils.convert_block import QuantType, convert_block
  10. from test_utils import MODEL_NAME
  11. KVCache = Tuple[torch.Tensor, torch.Tensor]
  12. class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
  13. def forward(
  14. self,
  15. hidden_states: torch.Tensor,
  16. *args,
  17. attention_mask: Optional[torch.Tensor] = None,
  18. alibi: Optional[torch.Tensor] = None,
  19. layer_past: Optional[KVCache] = None,
  20. use_cache: bool = False,
  21. **kwargs,
  22. ):
  23. batch_size, seq_length = hidden_states.shape[:2]
  24. if layer_past is not None:
  25. layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
  26. past_length = 0 if layer_past is None else layer_past[0].shape[1]
  27. seq_length_with_past = seq_length + past_length
  28. attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
  29. if alibi is None and self.config.alibi:
  30. alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
  31. attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
  32. outputs = super().forward(
  33. hidden_states,
  34. *args,
  35. attention_mask=attention_mask,
  36. alibi=alibi,
  37. layer_past=layer_past,
  38. use_cache=use_cache,
  39. **kwargs,
  40. )
  41. if use_cache:
  42. present_key_value = outputs[-1]
  43. present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
  44. outputs = outputs[:-1] + (present_key_value,)
  45. return outputs
  46. def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
  47. key_states, value_states = key_value
  48. key_states = key_states.permute(0, 2, 1)
  49. assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
  50. if self.config.new_decoder_architecture:
  51. key_states = self._expand_states(key_states)
  52. value_states = self._expand_states(value_states)
  53. return (key_states, value_states)
  54. def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
  55. key_states, value_states = key_value
  56. if self.config.new_decoder_architecture:
  57. key_states = self._collapse_states(key_states)
  58. value_states = self._collapse_states(value_states)
  59. assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
  60. key_states = key_states.permute(0, 2, 1)
  61. return (key_states, value_states)
  62. def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
  63. batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
  64. batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
  65. state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
  66. state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
  67. state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
  68. return state
  69. def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
  70. batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
  71. batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
  72. state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
  73. state = state[:, :, 0]
  74. state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
  75. return state
  76. class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
  77. def forward(
  78. self,
  79. hidden_states: torch.Tensor,
  80. *args,
  81. attention_mask: Optional[torch.Tensor] = None,
  82. position_ids: Optional[torch.LongTensor] = None,
  83. layer_past: Optional[Tuple[torch.Tensor]] = None,
  84. use_cache: bool = False,
  85. **kwargs,
  86. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  87. batch_size, seq_length, _ = hidden_states.shape
  88. seq_length_with_past = seq_length
  89. past_key_values_length = 0
  90. past_key_value = layer_past
  91. if past_key_value is not None:
  92. past_key_values_length = past_key_value[0].shape[2]
  93. seq_length_with_past = seq_length_with_past + past_key_values_length
  94. past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
  95. elif use_cache:
  96. past_key_value = DynamicCache()
  97. if position_ids is None:
  98. device = hidden_states.device
  99. position_ids = torch.arange(
  100. past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
  101. )
  102. position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
  103. else:
  104. position_ids = position_ids.view(-1, seq_length).long()
  105. # embed positions
  106. if attention_mask is None:
  107. attention_mask = torch.ones(
  108. (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
  109. )
  110. attention_mask = _prepare_4d_causal_attention_mask(
  111. attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
  112. )
  113. outputs = super().forward(
  114. hidden_states,
  115. *args,
  116. attention_mask=attention_mask,
  117. position_ids=position_ids,
  118. past_key_value=past_key_value,
  119. use_cache=use_cache,
  120. **kwargs,
  121. )
  122. if use_cache:
  123. present_key_value = outputs[-1]
  124. present_key_value = self._reorder_cache_from_llama_to_bloom(
  125. present_key_value, batch_size, seq_length_with_past
  126. )
  127. outputs = outputs[:-1] + (present_key_value,)
  128. return outputs
  129. def _reorder_cache_from_bloom_to_llama(
  130. self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
  131. ) -> DynamicCache:
  132. key_states, value_states = key_value
  133. key_states = key_states.permute(0, 2, 1)
  134. key_states = key_states.view(
  135. batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
  136. )
  137. value_states = value_states.view(*key_states.shape)
  138. past_key_values = ((key_states, value_states),)
  139. return DynamicCache.from_legacy_cache(past_key_values)
  140. def _reorder_cache_from_llama_to_bloom(
  141. self, key_value: DynamicCache, batch_size: int, seq_length: int
  142. ) -> Tuple[torch.Tensor]:
  143. key_states, value_states = key_value.to_legacy_cache()[0]
  144. value_states = value_states.view(
  145. batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
  146. )
  147. key_states = key_states.view(*value_states.shape)
  148. key_states = key_states.permute(0, 2, 1)
  149. return (key_states, value_states)
  150. @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
  151. @pytest.mark.forked
  152. def test_optimized_block(device):
  153. if device == "cuda:0" and not torch.cuda.is_available():
  154. pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
  155. config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
  156. tensor_parallel_devices = (device,)
  157. dtype = torch.bfloat16
  158. quant_type = QuantType.NONE
  159. block = config.block_class(config).to(dtype)
  160. block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
  161. if config.model_type == "falcon":
  162. unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
  163. elif config.model_type == "llama":
  164. unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
  165. else:
  166. pytest.skip(f"This test is not applicable to {config.model_type} models")
  167. unopt_block = convert_block(
  168. unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
  169. )
  170. unopt_block.load_state_dict(block.state_dict())
  171. cache = unopt_cache = None
  172. with torch.inference_mode():
  173. for length in [10, 1, 1, 1]:
  174. dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
  175. block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
  176. unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
  177. assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
  178. assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
  179. assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length