Explorar o código

working one batch

artek0chumak hai 3 anos
pai
achega
aee079981a
Modificouse 1 ficheiro con 2 adicións e 3 borrados
  1. 2 3
      src/client/remote_generation.py

+ 2 - 3
src/client/remote_generation.py

@@ -46,7 +46,7 @@ class RemoteGenerationMixin(PreTrainedModel):
 
         with self.transformer.h.inference_session() as sess:
             last_token_id = inputs[:, -1]
-            outputs = []
+            outputs = [inputs]
             while torch.any(last_token_id != eos_token_id):
                 embs = self.transformer.word_embeddings(inputs)
                 embs = self.transformer.word_embeddings_layernorm(embs)
@@ -61,6 +61,5 @@ class RemoteGenerationMixin(PreTrainedModel):
                     constraint.update(last_token_id, torch.ones_like(last_token_id))
                 outputs.append(last_token_id)
                 inputs = last_token_id
-            
-        return torch.cat(outputs, dim=-1)
 
+        return torch.cat(outputs, dim=-1)