
pack of fixes

Artem Chumachenko 2 years ago
commit ccfb6520fd
3 changed files with 58 additions and 18 deletions
  1. src/client/remote_generation.py (+31, -2)
  2. src/utils/generation_algorithms.py (+24, -13)
  3. tests/test_full_model.py (+3, -3)

src/client/remote_generation.py (+31, -2)

@@ -40,6 +40,7 @@ class RemoteGenerationMixin:
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
+        num_return_sequences: Optional[int] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
@@ -78,6 +79,8 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        batch_size = inputs.size(0)
+
         assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
             max_new_tokens = max_length - prefix_length
@@ -93,13 +96,21 @@ class RemoteGenerationMixin:
             if do_sample:
                 decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
             elif num_beams is not None and num_beams > 1:
-                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=inputs.size(0))
+                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
             else:
                 decoding_algorithm = GreedyAlgorithm()
 
         if num_beams > 1:
             inputs = torch.cat([inputs] * num_beams, dim=0)
 
+        if num_return_sequences is None:
+            num_return_sequences = 1
+
+        assert num_return_sequences <= num_beams, (
+            "You requested more return sequences than there are beams:"
+            f" num_return_sequences={num_return_sequences}, num_beams={num_beams}."
+        )
+
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
@@ -118,6 +129,7 @@ class RemoteGenerationMixin:
             last_token_id = None
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
+            hypo_ids_map = dict()
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
@@ -143,12 +155,29 @@ class RemoteGenerationMixin:
                 if num_beams > 1:
                     outputs[-1] = outputs[-1][hypo_ids]
 
+                if num_beams > 1:
+                    # retroactively reorder all previously generated tokens to follow the surviving hypotheses
+                    hypo_ids_map[len(outputs)] = hypo_ids
+                    cur_hypo_ids = hypo_ids.clone()
+                    for i in range(len(outputs), 1, -1):
+                        outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
+                        cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
+
                 outputs.append(last_token_id)
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                     break
 
-        return torch.cat(outputs, dim=-1)
+        outputs = torch.cat(outputs, dim=-1)
+
+        if num_beams > 1:
+            # regroup beam-major rows so each input's top num_return_sequences come out contiguous
+            pre_return_idx = [
+                torch.arange(idx, num_return_sequences * batch_size, batch_size)
+                for idx in range(batch_size)
+            ]
+            return_idx = torch.cat(pre_return_idx, dim=0)
+            outputs = outputs[return_idx]
+
+        return outputs
 
     def greedy_search(
         self,

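Note on the first file: `generate` now accepts `num_return_sequences` (default 1, capped at `num_beams`), and the per-step `hypo_ids_map` bookkeeping reorders previously generated tokens so every stored step follows the surviving beam hypotheses. Because the prompt is tiled beam-major via `torch.cat([inputs] * num_beams, dim=0)`, the final `return_idx` gather regroups the output rows so each input's top sequences come out contiguous. A standalone sketch of that gather, with illustrative values rather than the library's tensors:

    import torch

    batch_size, num_beams, num_return_sequences = 2, 3, 2
    # Rows are laid out beam-major: (beam 0, batch 0), (beam 0, batch 1), (beam 1, batch 0), ...
    rows = torch.tensor([[10 * b + s] for s in range(num_beams) for b in range(batch_size)])
    pre_return_idx = [
        torch.arange(idx, num_return_sequences * batch_size, batch_size)
        for idx in range(batch_size)
    ]
    return_idx = torch.cat(pre_return_idx, dim=0)
    print(rows[return_idx].tolist())  # [[0], [1], [10], [11]]: batch 0's top two rows, then batch 1's

In short, row `idx + beam * batch_size` holds beam `beam` of input `idx`, and the gather walks the first `num_return_sequences` beams for each input.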
src/utils/generation_algorithms.py (+24, -13)

@@ -1,5 +1,6 @@
 from abc import ABC
+from heapq import heappush, heappop
 from typing import Tuple
 
 import torch
 
@@ -81,32 +82,42 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self.batch_size = batch_size
 
         self._beams = []
 
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
 
+        self._beams = [(beam[0], beam[1] % self.num_beams) for beam in self._beams]
         if len(self._beams) > 0:
             new_beams = []
             for batch_idx in range(self.batch_size):
                 for beam_idx in range(self.num_beams):
-                    new_beam = self._beams[beam_idx]
+                    probs_idx = batch_idx + beam_idx * self.batch_size
+                    new_beam = self._beams[probs_idx]
                     for hypo_idx in range(self.num_beams):
-                        probs_idx = batch_idx + beam_idx * self.batch_size
-                        new_beams.append((beam_idx, new_beam[1] + probs[probs_idx, hypo_idx].item()))
-            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
-            self._beams = new_beams[: self.batch_size * self.num_beams]
+                        # push (cumulative_log_prob, packed_index) with packed_index = beam_idx * num_beams + hypo_idx
+                        heappush(
+                            new_beams,
+                            (
+                                new_beam[0] + probs[probs_idx, hypo_idx].item(),
+                                beam_idx * self.num_beams + hypo_idx,
+                            ),
+                        )
+                        # bounded min-heap: evict the worst candidate once past batch_size * num_beams
+                        if len(new_beams) > self.batch_size * self.num_beams:
+                            heappop(new_beams)
+            self._beams = new_beams
         else:
             for batch_idx in range(self.batch_size):
                 for beam_idx in range(self.num_beams):
-                    self._beams.append((beam_idx, probs[batch_idx, beam_idx].item()))
+                    self._beams.append((probs[batch_idx, beam_idx].item(), beam_idx))
 
         return_hypos = []
         return_tokens = []
         for batch_idx in range(self.batch_size):
             for beam_idx in range(self.num_beams):
-                hypo_idx = batch_idx + beam_idx * self.batch_size
-                return_hypos.append(self._beams[hypo_idx][0])
-                return_tokens.append([sorted_indices[batch_idx, beam_idx].item()])
+                beam = self._beams[batch_idx + beam_idx * self.batch_size]
+                hypo_idx = beam[1] // self.num_beams  # parent hypothesis the candidate extends
+                token_idx = beam[1] % self.num_beams  # rank of the chosen token within that hypothesis
+                return_hypos.append(hypo_idx)
+                return_tokens.append([sorted_indices[hypo_idx, token_idx].item()])
 
         return torch.tensor(return_tokens), torch.tensor(return_hypos)

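Note on the second file: `BeamSearchAlgorithm.__call__` now stores each beam as a `(cumulative_log_prob, packed_index)` tuple, where `packed_index = beam_idx * num_beams + hypo_idx` is later unpacked with `//` (parent hypothesis) and `%` (token rank). Candidate selection switches from sorting everything to a bounded min-heap: push every candidate, and pop the smallest once the heap exceeds `batch_size * num_beams`. A minimal sketch of that pattern in isolation, with made-up scores:

    from heapq import heappush, heappop

    def top_k_candidates(scores, k):
        heap = []  # min-heap of (score, packed_index) tuples
        for packed_index, score in enumerate(scores):
            heappush(heap, (score, packed_index))
            if len(heap) > k:
                heappop(heap)  # evict the current worst candidate
        return heap  # the k best pairs, in heap order (not sorted)

    print(top_k_candidates([0.1, -0.5, 0.7, 0.2, -1.3, 0.4], k=3))
    # -> [(0.2, 3), (0.4, 5), (0.7, 2)]

This caps per-step selection at O(n log k) instead of O(n log n) for a full sort; the trade-off is that the surviving beams are left in heap order rather than ranked order.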
tests/test_full_model.py (+3, -3)

@@ -98,7 +98,8 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
-    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    text = "A cat sat on a mat"
+    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
@@ -110,9 +111,8 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
         device=inputs.device,
         length_penalty=0,
         do_early_stopping=False,
-        num_beam_hyps_to_keep=2,
     )
-    hf_inputs = tokenizer(["A cat sat on a mat"] * 2, return_tensors="pt")["input_ids"]
+    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
     hf_outputs = BloomForCausalLM.beam_search(
         model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
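Note on the test: the Hugging Face reference now gets two return sequences by batching the same prompt twice (`[text] * 2`) rather than by asking the scorer for two hypotheses per input, which is why `num_beam_hyps_to_keep=2` is dropped (the default of 1 applies). The `beam_scorer` construction itself is outside the shown hunks; a hedged sketch of what it presumably looks like, where the keyword arguments mirror the diff and the placeholder `inputs`/`num_beams` values are assumptions:

    import torch
    from transformers import BeamSearchScorer  # a top-level export in transformers 4.x

    num_beams = 2
    inputs = torch.tensor([[1, 2, 3]])  # placeholder token ids standing in for the tokenized prompt

    beam_scorer = BeamSearchScorer(
        batch_size=inputs.size(0),
        num_beams=num_beams,
        device=inputs.device,
        length_penalty=0,
        do_early_stopping=False,
    )

With `num_beam_hyps_to_keep` left at its default, each of the two identical prompts contributes one finished hypothesis, matching the two sequences the remote side returns.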