Kaynağa Gözat

working one batch

artek0chumak 3 yıl önce
ebeveyn
işleme
aee079981a
1 değiştirilmiş dosya ile 2 ekleme ve 3 silme
  1. 2 3
      src/client/remote_generation.py

+ 2 - 3
src/client/remote_generation.py

@@ -46,7 +46,7 @@ class RemoteGenerationMixin(PreTrainedModel):
 
         with self.transformer.h.inference_session() as sess:
             last_token_id = inputs[:, -1]
-            outputs = []
+            outputs = [inputs]
             while torch.any(last_token_id != eos_token_id):
                 embs = self.transformer.word_embeddings(inputs)
                 embs = self.transformer.word_embeddings_layernorm(embs)
@@ -61,6 +61,5 @@ class RemoteGenerationMixin(PreTrainedModel):
                     constraint.update(last_token_id, torch.ones_like(last_token_id))
                 outputs.append(last_token_id)
                 inputs = last_token_id
-            
-        return torch.cat(outputs, dim=-1)
 
+        return torch.cat(outputs, dim=-1)