瀏覽代碼

support transformers 4.38.2

Your Name 1 年之前
父節點
當前提交
6278395227
共有 3 個文件被更改,包括 16 次插入5 次删除
  1. 1 1
      src/petals/__init__.py
  2. 13 4
      src/petals/models/llama/block.py
  3. 2 0
      src/petals/models/llama/model.py

+ 1 - 1
src/petals/__init__.py

@@ -22,7 +22,7 @@ __version__ = "2.3.0.dev2"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
     ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.39.0"
 
 

+ 13 - 4
src/petals/models/llama/block.py

@@ -50,9 +50,15 @@ class OptimizedLlamaAttention(LlamaAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
@@ -84,9 +90,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -160,6 +165,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = residual + hidden_states

+ 2 - 0
src/petals/models/llama/model.py

@@ -47,6 +47,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert cache_position is None, "cache_position is only supported for dedicated inference"
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"