Sfoglia il codice sorgente

Make beam_search identical

Artem Chumachenko 2 anni fa
parent
commit
9bde866eb3

+ 4 - 2
src/client/remote_generation.py

@@ -140,11 +140,13 @@ class RemoteGenerationMixin:
                         :, seq_idx : seq_idx + 1
                     ] + pad_token_mask * last_token_id
 
-                if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
-                    break
+                if num_beams > 1:
+                    outputs[-1] = outputs[-1][hypo_ids]
 
                 outputs.append(last_token_id)
                 seq_idx += 1
+                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
+                    break
 
         return torch.cat(outputs, dim=-1)
 

+ 26 - 22
src/utils/generation_algorithms.py

@@ -80,29 +80,33 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._logits = torch.zeros(
-            (
-                self.batch_size,
-                self._cur_num_beams,
-            )
-        )
+        self._beams = []
 
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        probs = torch.softmax(sorted_logits, -1)
-
-        new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
+        probs = torch.log_softmax(sorted_logits, -1)
+
+        if len(self._beams) > 0:
+            new_beams = []
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    new_beam = self._beams[beam_idx]
+                    for hypo_idx in range(self.num_beams):
+                        probs_idx = batch_idx + beam_idx * self.batch_size
+                        new_beams.append((beam_idx, new_beam[1] + probs[probs_idx, hypo_idx].item()))
+            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+            self._beams = new_beams[: self.batch_size * self.num_beams]
+        else:
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    self._beams.append((beam_idx, probs[batch_idx, beam_idx].item()))
+
+        return_hypos = []
+        return_tokens = []
         for batch_idx in range(self.batch_size):
-            for cur_beam_idx in range(self._cur_num_beams):
-                for new_beam_idx in range(self.num_beams):
-                    logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
-                    new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
-        self._cur_num_beams = self.num_beams
-
-        new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
-        new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
-        self._logits = new_sorted_logits[:, : self.num_beams]
-        result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
-        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
-
-        return result_tokens.unsqueeze(-1), result_hypos
+            for beam_idx in range(self.num_beams):
+                hypo_idx = batch_idx + beam_idx * self.batch_size
+                return_hypos.append(self._beams[hypo_idx][0])
+                return_tokens.append([sorted_indices[batch_idx, beam_idx].item()])
+
+        return torch.tensor(return_tokens), torch.tensor(return_hypos)

+ 27 - 0
tests/test_full_model.py

@@ -3,6 +3,7 @@ import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
+from transformers.generation_utils import BeamSearchScorer
 
 from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,29 @@ def test_greedy_generation(max_new_tokens=4):
     assert torch.allclose(
         remote_outputs_batch, hf_outputs_batch
     ), "Greedy search are not identical to HF in multibatch mode"
+
+
+@pytest.mark.forked
+def test_greedy_generation(max_new_tokens=4, num_beams=2):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+        num_beams=num_beams,
+    )
+    beam_scorer = BeamSearchScorer(
+        batch_size=inputs.size(0),
+        num_beams=num_beams,
+        device=inputs.device,
+        length_penalty=0,
+        do_early_stopping=False,
+        num_beam_hyps_to_keep=2,
+    )
+    hf_outputs = BloomForCausalLM.beam_search(
+        model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+    )
+    assert torch.allclose(remote_outputs, hf_outputs), "Beam search are not identical to HF"