|
@@ -144,7 +144,6 @@ class RemoteGenerationMixin:
|
|
|
last_token_id = None
|
|
|
seq_idx = outputs[0].size(1)
|
|
|
hypo_ids = torch.arange(outputs[0].size(0))
|
|
|
- hypo_ids_map = dict()
|
|
|
while True:
|
|
|
embs = self.transformer.word_embeddings(outputs[-1])
|
|
|
intermediate_prompts = None
|
|
@@ -176,8 +175,7 @@ class RemoteGenerationMixin:
|
|
|
hypo_ids_map[len(outputs)] = hypo_ids
|
|
|
cur_hypo_ids = torch.tensor(hypo_ids)
|
|
|
for i in range(len(outputs), 1, -1):
|
|
|
- outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
|
|
|
- cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
|
|
|
+ outputs[i - 1] = outputs[i - 1][hypo_ids]
|
|
|
|
|
|
outputs.append(last_token_id)
|
|
|
seq_idx += 1
|