File overview

Init inference generation

artek0chumak 3 years ago
parent
commit
b29937b74c
4 changed files with 95 additions and 3 deletions
  1. + 3 - 3 src/client/remote_model.py
  2. + 0 - 0 src/utils/__init__.py
  3. + 57 - 0 src/utils/generation_algorithms.py
  4. + 35 - 0 src/utils/generation_constraints.py

+ 3 - 3
src/client/remote_model.py

@@ -16,6 +16,7 @@ from src.bloom.model import (
     LMHead,
 )
 from src.client.remote_sequential import RemoteSequential
+from src.client.remote_generation import RemoteGenerationMixin
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -137,9 +138,8 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return transformer_outputs
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM):
-    """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
-
+class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
+    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
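
A note on the base-class order above: under Python's method resolution order, BloomForCausalLM is searched before RemoteGenerationMixin, so any method both classes define (for example a generate() override that the mixin, imported from src/client/remote_generation and not shown in this diff, presumably provides) resolves to the BloomForCausalLM version. A minimal sketch with hypothetical stand-in classes illustrates the lookup:

    class RemoteMixin:
        def generate(self):
            return "remote"

    class Base:
        def generate(self):
            return "local"

    class ModelA(Base, RemoteMixin):  # the order used in the diff above
        pass

    class ModelB(RemoteMixin, Base):  # mixin-first order
        pass

    print(ModelA().generate())  # "local"  -- Base shadows the mixin
    print(ModelB().generate())  # "remote" -- the mixin takes precedence

If RemoteGenerationMixin is meant to override generation behavior, listing it first would be the safer order.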

+ 0 - 0
src/utils/__init__.py (empty file)


+ 57 - 0
src/utils/generation_algorithms.py

@@ -0,0 +1,57 @@
+import torch
+
+from abc import ABC
+from typing import Tuple
+
+TokenIds = torch.Tensor
+BatchIds = torch.Tensor
+
+
+class DecodingAlgorithm(ABC):
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
+        pass
+
+
+class GreedyAlgorithm(DecodingAlgorithm):
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
+        return logits.max(-1)[1], torch.arange(logits.size(0))  # argmax over the vocab dim; the last position is not sliced out here
+
+
+class TopKAlgorithm(DecodingAlgorithm):
+    # TODO: Add NumHypos, maxBatchSize
+    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        self.top_k = top_k
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
+        logits = logits[:, -1]  # last position only; note this is a view into the caller's tensor
+        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
+        logits[indices_to_remove] = -float("Inf")  # in-place masking (mutates the input through the view)
+        probs = torch.softmax(logits / self.temperature, -1)
+        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+
+
+class NucleusAlgorithm(DecodingAlgorithm):
+    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        self.top_p = top_p
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, BatchIds]:
+        logits = logits[:, -1]
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = torch.softmax(sorted_logits / self.temperature, -1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
+        sorted_indices_to_remove = cumulative_probs > self.top_p  # drop tokens outside the top_p nucleus
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()  # shift right: keep the first token crossing the threshold
+        sorted_indices_to_remove[..., 0] = False  # always keep the most probable token
+        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
+        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)  # map the mask back to vocabulary order
+        logits[indices_to_remove] = -float("Inf")
+        probs = torch.softmax(logits / self.temperature, -1)
+        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+
+
+# TODO: In the generate() function, check which decoding algorithm (top_k or another sampling algorithm) should be used
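
A quick usage sketch of the three algorithms (the shapes are assumptions: TopKAlgorithm and NucleusAlgorithm slice out the last sequence position themselves, so they expect [batch, seq_len, vocab] logits, while GreedyAlgorithm takes an argmax over whatever it is given; .clone() is used because both samplers mask the input logits in place):

    import torch

    from src.utils.generation_algorithms import GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm

    batch, seq_len, vocab = 2, 5, 10  # toy sizes
    logits = torch.randn(batch, seq_len, vocab)

    token_ids, batch_ids = TopKAlgorithm(top_k=3, temperature=0.8)(logits.clone())
    print(token_ids.shape, batch_ids)  # torch.Size([2, 1]) tensor([0, 1])

    token_ids, batch_ids = NucleusAlgorithm(top_p=0.9)(logits.clone())
    print(token_ids.shape)  # torch.Size([2, 1])

    # GreedyAlgorithm does not slice the last position, so pass it explicitly.
    token_ids, batch_ids = GreedyAlgorithm()(logits[:, -1])
    print(token_ids.shape)  # torch.Size([2])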

+ 35 - 0
src/utils/generation_constraints.py

@@ -0,0 +1,35 @@
+import torch
+
+from abc import ABC
+
+
+class ABConstraint(ABC):
+    def __init__(self) -> None:
+        pass
+
+    def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
+        pass
+
+    def consume_prefix(self, prefix: torch.Tensor) -> None:
+        pass
+
+    def calculate_transition(self, logits: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class MaxNewTokensConstraint(ABConstraint):
+    def __init__(self, max_new_tokens: int, eos_token_id: int, min_logits: float = -100000) -> None:
+        self.max_new_tokens = max_new_tokens
+        self.current_generated_tokens = 0
+        self.eos_token_id = eos_token_id
+        self.min_logits = min_logits
+
+    def update(self, token_id: torch.Tensor, is_started: torch.Tensor) -> None:
+        self.current_generated_tokens += 1
+
+    def calculate_transition(self, logits: torch.Tensor) -> torch.Tensor:
+        if self.current_generated_tokens >= self.max_new_tokens:  # budget exhausted: force EOS
+            mask = torch.zeros_like(logits)
+            mask[..., self.eos_token_id] = 1
+            logits += self.min_logits * (1 - mask)  # suppress everything except eos_token_id
+        return logits
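
A hedged sketch of how the constraint might be driven from a decoding loop (the loop and the toy sizes are assumptions; only MaxNewTokensConstraint comes from the diff above):

    import torch

    from src.utils.generation_constraints import MaxNewTokensConstraint

    vocab, eos_token_id = 10, 2  # toy vocabulary size and an assumed EOS id
    constraint = MaxNewTokensConstraint(max_new_tokens=3, eos_token_id=eos_token_id)

    for step in range(5):
        logits = torch.randn(2, vocab)  # stand-in model output for a batch of 2
        logits = constraint.calculate_transition(logits)
        token_ids = logits.argmax(-1)
        constraint.update(token_ids, is_started=torch.ones(2, dtype=torch.bool))
        print(step, token_ids.tolist())  # once the budget is spent, every token is eos_token_id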