|
@@ -152,7 +152,7 @@ class RemoteGenerationMixin:
|
|
|
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
|
|
|
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
|
|
|
|
|
|
- # If samples have padded, so changes only them
|
|
|
+ # If some samples were padded, change only these samples
|
|
|
if seq_idx < inputs.size(1):
|
|
|
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
|
|
|
last_token_id = (~pad_token_mask) * inputs[
|