@@ -40,6 +40,7 @@ class RemoteGenerationMixin:
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
+        num_return_sequences: Optional[int] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
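The first hunk threads the new `num_return_sequences` argument through `generate`. A minimal usage sketch, assuming a Hugging Face-style `tokenizer` and a `model` exposing this mixin (both names are illustrative, not part of the diff):

```python
# Sketch only: `tokenizer` and `model` are assumed stand-ins, not defined
# in this diff.
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(
    inputs,
    max_new_tokens=5,
    num_beams=4,
    num_return_sequences=2,  # at most num_beams; defaults to 1 when omitted
)
# Expected layout: outputs.size(0) == inputs.size(0) * num_return_sequences
```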
@@ -78,6 +79,8 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        batch_size = inputs.size(0)
+
         assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
             max_new_tokens = max_length - prefix_length
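`batch_size` is captured here, before beam search tiles `inputs` in a later hunk, so that it still reflects the caller's batch size. The `max_length`/`max_new_tokens` bookkeeping is a one-line subtraction; a toy illustration with made-up numbers:

```python
# Made-up numbers: a prompt of 8 tokens with max_length=20 leaves room
# for 12 newly generated tokens.
prefix_length = 8
max_length = 20
max_new_tokens = max_length - prefix_length
assert max_new_tokens == 12
```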
@@ -93,13 +96,21 @@ class RemoteGenerationMixin:
         if do_sample:
             decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
         elif num_beams is not None and num_beams > 1:
-            decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=inputs.size(0))
+            decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
         else:
             decoding_algorithm = GreedyAlgorithm()
 
         if num_beams > 1:
             inputs = torch.cat([inputs] * num_beams, dim=0)
 
+        if num_return_sequences is None:
+            num_return_sequences = 1
+
+        assert num_return_sequences <= num_beams, (
+            "You want more sequences than the beam search will return."
+            f" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
+        )
+
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
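Beam search replicates each input row `num_beams` times along dim 0, which makes the resulting layout beam-major and explains why the pre-expansion `batch_size` must feed `BeamSearchAlgorithm`. A shape sketch with toy values:

```python
import torch

# Toy shapes: 2 prompts of length 5, expanded for num_beams=3.
inputs = torch.zeros(2, 5, dtype=torch.long)
batch_size = inputs.size(0)                      # 2, recorded pre-expansion
num_beams, num_return_sequences = 3, 2
inputs = torch.cat([inputs] * num_beams, dim=0)  # beam-major rows: (6, 5)
assert inputs.size(0) == batch_size * num_beams
assert num_return_sequences <= num_beams         # mirrors the new assert
```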
@@ -118,6 +129,7 @@ class RemoteGenerationMixin:
         last_token_id = None
         seq_idx = outputs[0].size(1)
         hypo_ids = torch.arange(outputs[0].size(0))
+        hypo_ids_map = dict()
         while True:
             embs = self.transformer.word_embeddings(outputs[-1])
             intermediate_prompts = None
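`hypo_ids_map` records, keyed by `len(outputs)` at each step, which parent hypothesis every beam row descended from, so that tokens emitted earlier can be realigned after beams are reordered (see the next hunk). A sketch of the contents it might hold, with assumed toy values:

```python
import torch

# Assumed toy contents for batch_size=1, num_beams=3: when len(outputs)
# was 2, rows 0 and 1 both continued hypothesis 0 and row 2 continued
# hypothesis 1; one step later the best row came from beam 2.
hypo_ids_map = {
    2: torch.tensor([0, 0, 1]),
    3: torch.tensor([2, 0, 1]),
}
```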
@@ -143,12 +155,29 @@ class RemoteGenerationMixin:
             if num_beams > 1:
                 outputs[-1] = outputs[-1][hypo_ids]
 
+            if num_beams > 1:
+                hypo_ids_map[len(outputs)] = hypo_ids
+                cur_hypo_ids = hypo_ids.clone()
+                for i in range(len(outputs), 1, -1):
+                    outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
+                    cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
+
             outputs.append(last_token_id)
             seq_idx += 1
             if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                 break
 
-        return torch.cat(outputs, dim=-1)
+        outputs = torch.cat(outputs, dim=-1)
+
+        if num_beams > 1:
+            pre_return_idx = [
+                torch.arange(idx, num_return_sequences * batch_size, batch_size)
+                for idx in range(batch_size)
+            ]
+            return_idx = torch.cat(pre_return_idx, dim=0)
+            outputs = outputs[return_idx]
+
+        return outputs
 
     def greedy_search(
         self,
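The realignment loop walks `outputs` backwards, re-permuting earlier token blocks with the recorded hypothesis ids. The final selection then exploits the beam-major layout produced by the earlier `torch.cat`: the sequence for batch item `idx` under beam `k` sits at row `idx + k * batch_size`. A worked example of the index arithmetic with toy values:

```python
import torch

# Toy values: batch_size=2, num_return_sequences=2 over beam-major rows.
batch_size, num_return_sequences = 2, 2
pre_return_idx = [
    torch.arange(idx, num_return_sequences * batch_size, batch_size)
    for idx in range(batch_size)
]
# pre_return_idx == [tensor([0, 2]), tensor([1, 3])]
return_idx = torch.cat(pre_return_idx, dim=0)  # tensor([0, 2, 1, 3])
# Rows regroup so all returned hypotheses for item 0 come first, then
# item 1, matching the (batch_size * num_return_sequences, seq_len) layout.
```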