
Bump transformers to 4.43.1 (#596)

* Update setup.cfg to transformers 4.43.1
* Update __init__.py to transformers 4.43.1
* Add cache_position check for Mixtral

Co-authored-by: xtinkt <ant.sinitsin@gmail.com>
Co-authored-by: Anton Sinitsin <30695750+xtinkt@users.noreply.github.com>
justheuristic committed 1 year ago
Commit 6477cb85e7
3 files changed, 6 additions and 3 deletions
  1. setup.cfg (+1, -1)
  2. src/petals/__init__.py (+2, -2)
  3. src/petals/models/mixtral/model.py (+3, -0)

+ 1 - 1
setup.cfg

@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.41.2  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.43.1  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
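
The comment in setup.cfg is a reminder that this pin and the range assert in petals/__init__.py (next hunk) must move together. As a quick sanity check, a minimal sketch that verifies the installed transformers against the new range, using the same packaging/version comparison the assert relies on:

    # Sketch: confirm the installed transformers matches the new pin.
    # Mirrors the range enforced by the assert in petals/__init__.py below.
    import transformers
    from packaging import version

    installed = version.parse(transformers.__version__)
    assert version.parse("4.43.1") <= installed < version.parse("4.44.0"), (
        f"Expected transformers>=4.43.1,<4.44.0, found {transformers.__version__}"
    )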

+ 2 - 2
src/petals/__init__.py

@@ -22,8 +22,8 @@ __version__ = "2.3.0.dev2"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.41.2") <= version.parse(transformers.__version__) < version.parse("4.42.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.41.2,<4.42.0"
+        version.parse("4.43.1") <= version.parse(transformers.__version__) < version.parse("4.44.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.43.1,<4.44.0"
 
 
 def _override_bfloat16_mode_default():
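
Since the gate is wrapped in "if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION")", setting that variable to any non-empty value before importing petals skips the assert entirely. A hedged sketch of the opt-out (unsupported transformers versions may still break at runtime):

    # Sketch: bypass the transformers version gate shown above.
    # petals/__init__.py only runs the assert when PETALS_IGNORE_DEPENDENCY_VERSION is unset,
    # so any non-empty value disables the range check (use at your own risk).
    import os

    os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"

    import petals  # imported without the transformers version assert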

+ 3 - 0
src/petals/models/mixtral/model.py

@@ -55,6 +55,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -70,6 +71,8 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"
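
Starting with transformers 4.43, generation passes a cache_position tensor into the model's forward, which the previous signature did not accept. Petals only supports the case where cache_position mirrors position_ids, hence the new assert above. A minimal standalone sketch of the same consistency check (the helper name is illustrative, not part of the commit):

    from typing import Optional

    import torch

    def check_cache_position(
        cache_position: Optional[torch.LongTensor],
        position_ids: Optional[torch.LongTensor],
    ) -> None:
        # Same condition as the new assert in DistributedMixtralModel.forward:
        # if cache_position is given, position_ids must be present and broadcast-equal to it.
        if cache_position is not None:
            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()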
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"