test_optimized_layers.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from typing import Optional, Tuple
  2. import pytest
  3. import torch
  4. from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
  5. from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
  6. from petals.utils.auto_config import AutoDistributedConfig
  7. from petals.utils.convert_block import QuantType, convert_block
  8. from test_utils import MODEL_NAME
  # A (key, value) cache pair as passed between calls of the wrapped blocks.
  KVCache = Tuple[torch.Tensor, torch.Tensor]


  class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
      """Reference (unoptimized) Falcon decoder layer that exchanges a bloom-style KV cache.

      Cache tensors passed in as ``layer_past`` and returned as the present cache are
      ``(keys, values)`` where keys are ``[batch * num_kv_heads, head_dim, seq_len]`` and
      values are ``[batch * num_kv_heads, seq_len, head_dim]``. They are converted to and
      from the ``[batch * heads, seq_len, head_dim]`` layout handed to
      ``FalconDecoderLayer.forward``.
      """

      def forward(
          self,
          hidden_states: torch.Tensor,
          *args,
          attention_mask: Optional[torch.Tensor] = None,
          alibi: Optional[torch.Tensor] = None,
          layer_past: Optional[KVCache] = None,
          use_cache: bool = False,
          **kwargs,
      ):
          """Run the Falcon layer, translating the cache layout on the way in and out.

          NOTE: any ``attention_mask`` passed by the caller is discarded. An all-ones mask over
          past + current tokens is always built here and turned into Falcon's attention mask
          via ``FalconModel._prepare_attn_mask``. ``alibi`` is derived from that mask only
          when it is not supplied and ``config.alibi`` is set. When ``use_cache`` is true,
          the last output element (the present cache) is converted back to the bloom-style
          layout.
          """
          bsz, q_len = hidden_states.shape[:2]

          past_len = 0
          if layer_past is not None:
              layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
              # After reordering, tensors are [batch * heads, seq_len, head_dim].
              past_len = layer_past[0].shape[1]

          # The caller-provided attention_mask is always replaced by an all-ones padding mask.
          padding_mask = torch.ones((bsz, q_len + past_len), device=hidden_states.device)
          if alibi is None and self.config.alibi:
              alibi = build_alibi_tensor(padding_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
          prepared_mask = FalconModel._prepare_attn_mask(padding_mask, (bsz, q_len), past_len)

          outputs = super().forward(
              hidden_states,
              *args,
              attention_mask=prepared_mask,
              alibi=alibi,
              layer_past=layer_past,
              use_cache=use_cache,
              **kwargs,
          )
          if not use_cache:
              return outputs

          falcon_present = outputs[-1]
          return outputs[:-1] + (self._reorder_cache_from_falcon_to_bloom(falcon_present),)

      def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
          """Convert bloom-style ``(keys, values)`` into the layout used by ``FalconDecoderLayer``.

          Keys go from ``[B * kv_heads, head_dim, seq_len]`` to ``[B * kv_heads, seq_len, head_dim]``.
          With ``new_decoder_architecture``, both tensors are then repeated up to
          ``num_attention_heads`` via ``_expand_states``.
          """
          keys, values = key_value
          keys = keys.permute(0, 2, 1)
          assert keys.shape == values.shape  # both [batch_size * num_kv_heads, seq_len, head_dim]
          if self.config.new_decoder_architecture:
              keys, values = self._expand_states(keys), self._expand_states(values)
          return keys, values

      def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
          """Convert Falcon's present ``(keys, values)`` back into the bloom-style layout.

          With ``new_decoder_architecture``, repeated heads are first collapsed back to
          ``num_kv_heads`` via ``_collapse_states``. Keys are then transposed to
          ``[B * kv_heads, head_dim, seq_len]``.
          """
          keys, values = key_value
          if self.config.new_decoder_architecture:
              keys, values = self._collapse_states(keys), self._collapse_states(values)
          assert keys.shape == values.shape  # both [batch_size * num_kv_heads, seq_len, head_dim]
          return keys.permute(0, 2, 1), values

      def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
          """Repeat each KV head ``num_key_value_groups`` times.

          Maps ``[B * num_kv_heads, seq_len, head_dim]`` to
          ``[B * num_attention_heads, seq_len, head_dim]``.
          """
          flat_heads, seq_len, head_dim = state.shape
          cfg = self.config
          bsz = flat_heads // cfg.num_kv_heads
          grouped = state.view(bsz, cfg.num_kv_heads, 1, seq_len, head_dim)
          grouped = grouped.expand(-1, -1, cfg.num_key_value_groups, -1, -1)  # stride-0 broadcast, no copy yet
          return grouped.reshape(bsz * cfg.num_attention_heads, seq_len, head_dim)  # copies the repeated heads

      def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
          """Keep only the first head of every KV group.

          Maps ``[B * num_attention_heads, seq_len, head_dim]`` to
          ``[B * num_kv_heads, seq_len, head_dim]``.
          """
          flat_heads, seq_len, head_dim = state.shape
          cfg = self.config
          bsz = flat_heads // cfg.num_attention_heads
          grouped = state.view(bsz, cfg.num_kv_heads, cfg.num_key_value_groups, seq_len, head_dim)
          first_of_group = grouped[:, :, 0]
          return first_of_group.view(bsz * cfg.num_kv_heads, seq_len, head_dim)
  class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
      """Reference (unoptimized) Llama decoder layer that exchanges a bloom-style KV cache.

      ``layer_past`` and the returned present cache are ``(keys, values)``. Keys are
      ``[batch * num_key_value_heads, head_dim, seq_len]`` and values are
      ``[batch * num_key_value_heads, seq_len, head_dim]``. They are converted to and from
      the ``[batch, num_key_value_heads, seq_len, head_dim]`` tensors passed to
      ``LlamaDecoderLayer.forward`` as ``past_key_value``.
      """

      def forward(
          self,
          hidden_states: torch.Tensor,
          *args,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          layer_past: Optional[KVCache] = None,
          use_cache: bool = False,
          **kwargs,
      ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
          """Run the Llama layer, building position ids and the attention mask when needed.

          Args:
              hidden_states: 3D input; its first two dims are read as (batch, seq_len).
              attention_mask: optional 2D mask of shape (batch, past_len + seq_len);
                  defaults to all ones.
              position_ids: optional positions, reshaped to (-1, seq_len); defaults to
                  ``past_len .. past_len + seq_len - 1``.
              layer_past: optional bloom-style cache returned by a previous call.
              use_cache: if true, the last output element is the present cache converted
                  to the bloom-style layout.
          """
          batch_size, seq_length, _ = hidden_states.shape
          seq_length_with_past = seq_length
          past_key_values_length = 0
          past_key_value = layer_past
          if past_key_value is not None:
              # Bloom-style keys are [batch * kv_heads, head_dim, seq_len], so dim 2 is the past length.
              past_key_values_length = past_key_value[0].shape[2]
              seq_length_with_past = seq_length_with_past + past_key_values_length
              past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

          if position_ids is None:
              device = hidden_states.device
              position_ids = torch.arange(
                  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
              )
              position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
          else:
              position_ids = position_ids.view(-1, seq_length).long()

          # Turn the 2D (batch, past + seq) mask into the form the decoder layer consumes
          # (presumably the 4D causal mask built by LlamaModel's helper).
          if attention_mask is None:
              attention_mask = torch.ones(
                  (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
              )
          attention_mask = LlamaModel._prepare_decoder_attention_mask(
              None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
          )

          outputs = super().forward(
              hidden_states,
              *args,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_value=past_key_value,
              use_cache=use_cache,
              **kwargs,
          )

          if use_cache:
              present_key_value = outputs[-1]
              present_key_value = self._reorder_cache_from_llama_to_bloom(
                  present_key_value, batch_size, seq_length_with_past
              )
              outputs = outputs[:-1] + (present_key_value,)

          return outputs

      def _reorder_cache_from_bloom_to_llama(self, key_value: KVCache, batch_size: int, seq_length: int) -> KVCache:
          """Convert bloom-style ``(keys, values)`` to ``[batch, kv_heads, seq_len, head_dim]`` tensors."""
          key_states, value_states = key_value
          key_states = key_states.permute(0, 2, 1)
          # Only the leading dim is split here, so `.view` works even on the permuted keys.
          key_states = key_states.view(
              batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
          )
          value_states = value_states.view(*key_states.shape)
          return (key_states, value_states)

      def _reorder_cache_from_llama_to_bloom(self, key_value: KVCache, batch_size: int, seq_length: int) -> KVCache:
          """Convert Llama's present ``(keys, values)`` back into the bloom-style layout.

          Both tensors are flattened to ``[batch * kv_heads, seq_len, head_dim]``; keys are
          then transposed to ``[batch * kv_heads, head_dim, seq_len]``.
          """
          key_states, value_states = key_value
          # `.reshape` rather than `.view`: this merges the batch and head dims. The present
          # tensors may be non-contiguous (NOTE(review): e.g. a transposed projection when
          # there is no past), and then `.view` fails for batch_size > 1. `.reshape` returns
          # the same view whenever one is possible and copies only otherwise.
          value_states = value_states.reshape(
              batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
          )
          key_states = key_states.reshape(*value_states.shape)
          key_states = key_states.permute(0, 2, 1)
          return (key_states, value_states)
  @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
  @pytest.mark.forked
  def test_optimized_block(device):
      """Check that the optimized block matches the unoptimized reference block.

      Both blocks share the same weights. They are fed a 10-token prompt followed by three
      single-token steps, each block reusing its own cache. After every step the outputs
      and both cache tensors must agree within ``atol=1e-6``.
      """
      if device == "cuda:0" and not torch.cuda.is_available():
          pytest.skip("CUDA tests can be run only in CUDA-enabled setups")

      config = AutoDistributedConfig.from_pretrained(MODEL_NAME)

      # Decide applicability before building/converting any block, so unsupported
      # model types skip without paying for block construction and conversion.
      reference_block_classes = {
          "falcon": UnoptimizedWrappedFalconBlock,
          "llama": UnoptimizedWrappedLlamaBlock,
      }
      if config.model_type not in reference_block_classes:
          pytest.skip(f"This test is not applicable to {config.model_type} models")

      tensor_parallel_devices = (device,)
      dtype = torch.bfloat16
      quant_type = QuantType.NONE

      block = config.block_class(config).to(dtype)
      block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

      unopt_block = reference_block_classes[config.model_type](config).to(dtype)
      unopt_block = convert_block(
          unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
      )
      # Give the reference the exact weights of the optimized block.
      unopt_block.load_state_dict(block.state_dict())

      cache = unopt_cache = None
      with torch.inference_mode():
          for length in [10, 1, 1, 1]:
              dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
              block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
              unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
              assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
              assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
              assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length