
test cache_position safely

Your Name, 1 year ago
Parent commit: 18ec785b31
1 changed file with 2 additions and 1 deletion

src/petals/models/llama/model.py (+2, -1)

@@ -63,7 +63,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
-        assert cache_position is None, "cache_position is only supported for dedicated inference"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"