|
@@ -6,13 +6,13 @@ from abc import ABC
|
|
|
class ABConstraint(ABC):
|
|
|
def __init__(self) -> None:
|
|
|
pass
|
|
|
-
|
|
|
+
|
|
|
def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
|
|
|
pass
|
|
|
-
|
|
|
+
|
|
|
def consume_prefix(self, prefix: torch.Tensor) -> None:
|
|
|
pass
|
|
|
-
|
|
|
+
|
|
|
def calculate_transation(self, logits: torch.Tensor) -> torch.Tensor:
|
|
|
pass
|
|
|
|
|
@@ -26,7 +26,7 @@ class MaxNewTokensConstraint(ABConstraint):
|
|
|
|
|
|
def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
|
|
|
self.current_generated_tokens += 1
|
|
|
-
|
|
|
+
|
|
|
def calculate_transation(self, logits: torch.Tensor) -> torch.Tensor:
|
|
|
if self.current_generated_tokens > self.max_new_tokens:
|
|
|
mask = torch.zeros_like(logits)
|