@@ -55,6 +55,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -70,6 +71,8 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"
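
For context, a minimal standalone sketch of the invariant the new assertions enforce (illustrative only; the `past_length`/`new_tokens` names and tensor shapes are assumptions, not Petals code): when `cache_position` is passed alongside `position_ids`, the two must agree element-wise, and positions must stay consecutive.

```python
import torch

# Assumed decoding state: 5 tokens already cached, 3 new tokens being processed.
past_length = 5
new_tokens = 3

# cache_position is a flat index into the KV cache; position_ids is per-batch.
cache_position = torch.arange(past_length, past_length + new_tokens)  # shape: (new_tokens,)
position_ids = cache_position.unsqueeze(0)                            # shape: (1, new_tokens)

# Mirrors the added guard: cache_position must match position_ids exactly...
assert torch.all(torch.eq(cache_position, position_ids)).item()
# ...and position_ids must be consecutive, as the existing assert requires.
assert (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
print("cache_position is consistent with position_ids")
```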