|
@@ -46,7 +46,7 @@ class RemoteGenerationMixin(PreTrainedModel):
|
|
|
|
|
|
with self.transformer.h.inference_session() as sess:
|
|
|
last_token_id = inputs[:, -1]
|
|
|
- outputs = []
|
|
|
+ outputs = [inputs]
|
|
|
while torch.any(last_token_id != eos_token_id):
|
|
|
embs = self.transformer.word_embeddings(inputs)
|
|
|
embs = self.transformer.word_embeddings_layernorm(embs)
|
|
@@ -61,6 +61,5 @@ class RemoteGenerationMixin(PreTrainedModel):
|
|
|
constraint.update(last_token_id, torch.ones_like(last_token_id))
|
|
|
outputs.append(last_token_id)
|
|
|
inputs = last_token_id
|
|
|
-
|
|
|
- return torch.cat(outputs, dim=-1)
|
|
|
|
|
|
+ return torch.cat(outputs, dim=-1)
|