
Use slightly less memory in .generate() (#177)

Alexander Borzunov 2 years ago
parent
commit
e27706358c
2 changed files with 9 additions and 7 deletions
  1. README.md (+1, -1)
  2. src/petals/client/remote_generation.py (+8, -6)

README.md (+1, -1)

@@ -50,7 +50,7 @@ sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --
 
 Check out more examples and tutorials:
 
-- Chatbot web app: [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat)
+- Chatbot web app (connects to Petals via an HTTP endpoint): [link](http://chat.petals.ml), [source code](https://github.com/borzunov/petals-chat)
 - Training a personified chatbot: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
 - Fine-tuning BLOOM for text semantic classification: [notebook](https://github.com/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
 - Launching your own swarm: [tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

src/petals/client/remote_generation.py (+8, -6)

@@ -40,7 +40,7 @@ class RemoteGenerationMixin:
 
         return self.transformer.h.inference_session(**kwargs)
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
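
For context: `torch.inference_mode()` behaves like `torch.no_grad()` but additionally skips autograd's per-tensor bookkeeping (version counters and view tracking), which is where the small memory saving comes from. A minimal standalone sketch of the difference, with illustrative tensors that are not from the Petals codebase:

```python
import torch

x = torch.randn(2, 8)

with torch.no_grad():
    y = x * 2  # gradients are disabled, but autograd metadata is still kept

with torch.inference_mode():
    z = x * 2  # no autograd bookkeeping at all, so tensors are cheaper

print(y.is_inference())  # False
print(z.is_inference())  # True: z can never be used in autograd later
```

The trade-off is that inference-mode tensors cannot re-enter autograd, which is fine for a pure inference path like `.generate()`.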
@@ -171,13 +171,15 @@ class RemoteGenerationMixin:
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
-                embs = self.transformer.word_embeddings(outputs[-1])
+                hidden_state = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
-                    embs = torch.cat([prompts, embs], dim=1)
-                embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                    prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
+                    hidden_state = torch.cat([prompts, hidden_state], dim=1)
+                hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
+
+                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
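
Beyond the `embs` → `hidden_state` rename, the rewritten loop threads a single variable through every stage, so each intermediate tensor loses its last reference as soon as the next stage rebinds the name; previously `embs` stayed alive for the rest of the loop body. A hedged sketch of the rebinding pattern, using stand-in modules rather than the actual Petals classes:

```python
import torch

# Stand-ins for word_embeddings, word_embeddings_layernorm, and session.step.
embed = torch.nn.Embedding(100, 16)
layernorm = torch.nn.LayerNorm(16)
step = torch.nn.Linear(16, 16)

tokens = torch.randint(0, 100, (1, 4))

with torch.inference_mode():
    hidden_state = embed(tokens)              # (1, 4, 16)
    hidden_state = layernorm(hidden_state)    # embedding output freed here
    hidden_state = step(hidden_state)[:, -1]  # layernorm output freed here
```

With one name, Python's reference counting can free each stage's output before the next allocation, trimming peak memory by roughly one activation tensor.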