
Add position_ids argument to DistributedFalconModel (#525)

Max Ryabinin, 1 year ago
Parent commit: ae19b65095
1 changed file with 4 additions and 0 deletions

src/petals/models/falcon/model.py (+4, -0)

@@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
         assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
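
For context, the new assertion only permits position IDs that increase by exactly 1 along each row; a nonzero starting offset is still allowed. Below is a minimal standalone sketch in plain PyTorch that mirrors the check. check_position_ids is a hypothetical helper written for illustration here, not part of the Petals API.

from typing import Optional

import torch

def check_position_ids(position_ids: Optional[torch.LongTensor]) -> None:
    # Mirrors the assertion added in this commit: within each row,
    # every position must be exactly one greater than the previous one.
    assert (
        position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
    ), f"Non-consecutive position_ids are not supported, {position_ids=}"

# Accepted: consecutive positions, even when starting from a nonzero offset
check_position_ids(torch.tensor([[0, 1, 2, 3], [5, 6, 7, 8]]))

# Rejected: the gap between positions 1 and 3 trips the assertion
try:
    check_position_ids(torch.tensor([[0, 1, 3, 4]]))
except AssertionError as e:
    print(e)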