Artem Chumachenko 3 anni fa
parent
commit
7ea1ea043b

+ 23 - 15
src/client/remote_generation.py

@@ -13,9 +13,10 @@ class RemoteGenerationMixin:
     The class exposes can be used for:
         - *greedy decoding*.
         - *multinomial sampling*.
-    
+
     This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
     """
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -33,7 +34,7 @@ class RemoteGenerationMixin:
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head.
-        
+
         :param inputs: The input tokens to the model.
         :param do_sample: Whether to sample from the model predictions or take the argmax.
         :param temperature: The temperature to use for sampling.
@@ -48,9 +49,15 @@ class RemoteGenerationMixin:
         :param model_kwargs: Additional arguments to pass to the model.
         """
 
-        assert model_kwargs.get("logits_processor", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
-        assert model_kwargs.get("logits_wrapper", None) is None, "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
-        assert model_kwargs.get("stopping_criteria", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        assert (
+            model_kwargs.get("logits_processor", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert (
+            model_kwargs.get("logits_wrapper", None) is None
+        ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert (
+            model_kwargs.get("stopping_criteria", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@@ -68,16 +75,16 @@ class RemoteGenerationMixin:
 
         constraints = self._get_constraints(
             inputs=inputs,
-            eos_token_id=eos_token_id, 
-            pad_token_id=pad_token_id, 
-            max_new_tokens=max_new_tokens, 
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            max_new_tokens=max_new_tokens,
             provided_constraints=provided_constraints,
         )
 
         with self.transformer.h.inference_session() as sess:
             outputs = []
-            if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
-                outputs += [inputs[:, :inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
+                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
             else:
                 outputs += [inputs]
             last_token_id = None
@@ -93,9 +100,11 @@ class RemoteGenerationMixin:
                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                 last_token_id, hypo_ids = decoding_algorithm(lm_logits)
-                if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
-                    pad_token_mask = inputs[:, seq_idx:seq_idx + 1] == pad_token_id
-                    last_token_id = (~pad_token_mask) * inputs[:, seq_idx:seq_idx + 1] + pad_token_mask * last_token_id
+                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[
+                        :, seq_idx : seq_idx + 1
+                    ] + pad_token_mask * last_token_id
 
                 if torch.all(last_token_id == eos_token_id):
                     break
@@ -147,7 +156,7 @@ class RemoteGenerationMixin:
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
-        
+
         :param: input_ids: The input tokens to the model.
         :param: temperature: The temperature to use for sampling.
         :param: top_k: The number of samples to use for top_k sampling.
@@ -229,4 +238,3 @@ class RemoteGenerationMixin:
             constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
         constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
         return constraints
-

+ 1 - 0
src/client/remote_model.py

@@ -144,6 +144,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
 class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):

+ 1 - 1
src/server/cache.py

@@ -25,7 +25,7 @@ class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
-        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
         self.device = device
         self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)

+ 6 - 2
src/server/handler.py

@@ -39,7 +39,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
-            cache_metadata = torch.tensor([[-1, -1] for _ in range(batch_size)], dtype=torch.int64)  # [cache_handle, prefix_length]
+            cache_metadata = torch.tensor(
+                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+            )  # [cache_handle, prefix_length]
             prefix_length = 0
 
             async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
@@ -225,7 +227,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                cache_descriptor = TensorDescriptor(size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+                cache_descriptor = TensorDescriptor(
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

+ 3 - 1
src/utils/generation_algorithms.py

@@ -11,6 +11,7 @@ class DecodingAlgorithm(ABC):
     """
     An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
     """
+
     def __init__(self) -> None:
         pass
 
@@ -26,6 +27,7 @@ class GreedyAlgorithm(DecodingAlgorithm):
     """
     The simpliest algorithm for decoding. It selects the most probable token.
     """
+
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
@@ -38,7 +40,7 @@ class SamplingAlgorithm(DecodingAlgorithm):
         """
         :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
         :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size). 
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
         """
         logits[indices_to_remove] = -float("Inf")
         probs = torch.softmax(logits / self.temperature, -1)

+ 10 - 5
src/utils/generation_constraints.py

@@ -7,10 +7,11 @@ class ABCBloomConstraint(ABC):
     """
     Base class of all kind of decoding constraints. It can be used to implement a new constraint.
     """
+
     def __init__(self) -> None:
         pass
 
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor,  hypo_ids: torch.Tensor) -> torch.Tensor:
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
         """
         This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
         :param tokens_id: The token id of the last choosen token.
@@ -31,7 +32,10 @@ class MaxNewTokensConstraint(ABCBloomConstraint):
         pad_token_id: The id of the padding token.
         min_logits: The minimum logits that can be generated. Default: -1e6.
     """
-    def __init__(self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
+
+    def __init__(
+        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
+    ) -> None:
         self.max_new_tokens = max_new_tokens
         self.current_generated_tokens = None
         self.eos_token_id = eos_token_id
@@ -44,7 +48,7 @@ class MaxNewTokensConstraint(ABCBloomConstraint):
         if tokens_id is not None:
             self.current_generated_tokens += 1
 
-        mask = (self.current_generated_tokens >= self.max_new_tokens)
+        mask = self.current_generated_tokens >= self.max_new_tokens
         logits += self.min_logits * mask
         logits[mask[:, 0], self.eos_token_id] = 0
         return logits
@@ -59,6 +63,7 @@ class EosConstraint(ABCBloomConstraint):
         pad_token_id: The id of the padding token.
         min_logits: The minimum logits that can be generated. Default: -1e6.
     """
+
     def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
         self.eos_token_id = eos_token_id
         self.min_logits = min_logits
@@ -68,10 +73,10 @@ class EosConstraint(ABCBloomConstraint):
 
     def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
         if self.past_tokens is not None:
-            mask = ((self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id))
+            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
             logits += self.min_logits * mask
             logits[mask[:, 0], self.eos_token_id] = 0
-        
+
         if tokens_id is not None:
             self.past_tokens = tokens_id
             self.wait_until_starting -= 1