Artem Chumachenko, 3 years ago
parent commit 7ea1ea043b

+ 23 - 15
src/client/remote_generation.py

@@ -13,9 +13,10 @@ class RemoteGenerationMixin:
     The class exposes can be used for:
         - *greedy decoding*.
         - *multinomial sampling*.
-    
+
     This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
     """
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -33,7 +34,7 @@ class RemoteGenerationMixin:
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head.
-        
+
         :param inputs: The input tokens to the model.
         :param do_sample: Whether to sample from the model predictions or take the argmax.
         :param temperature: The temperature to use for sampling.
@@ -48,9 +49,15 @@ class RemoteGenerationMixin:
         :param model_kwargs: Additional arguments to pass to the model.
         """

-        assert model_kwargs.get("logits_processor", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
-        assert model_kwargs.get("logits_wrapper", None) is None, "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
-        assert model_kwargs.get("stopping_criteria", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        assert (
+            model_kwargs.get("logits_processor", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert (
+            model_kwargs.get("logits_wrapper", None) is None
+        ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert (
+            model_kwargs.get("stopping_criteria", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"

         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@@ -68,16 +75,16 @@ class RemoteGenerationMixin:

         constraints = self._get_constraints(
             inputs=inputs,
-            eos_token_id=eos_token_id, 
-            pad_token_id=pad_token_id, 
-            max_new_tokens=max_new_tokens, 
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            max_new_tokens=max_new_tokens,
             provided_constraints=provided_constraints,
         )

         with self.transformer.h.inference_session() as sess:
             outputs = []
-            if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
-                outputs += [inputs[:, :inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
+                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
             else:
                 outputs += [inputs]
             last_token_id = None
@@ -93,9 +100,11 @@ class RemoteGenerationMixin:
                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                 last_token_id, hypo_ids = decoding_algorithm(lm_logits)
-                if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
-                    pad_token_mask = inputs[:, seq_idx:seq_idx + 1] == pad_token_id
-                    last_token_id = (~pad_token_mask) * inputs[:, seq_idx:seq_idx + 1] + pad_token_mask * last_token_id
+                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[
+                        :, seq_idx : seq_idx + 1
+                    ] + pad_token_mask * last_token_id

                 if torch.all(last_token_id == eos_token_id):
                     break
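As a side note on the pad-token handling touched in the two hunks above: before the loop the prompt is trimmed by the longest run of padding, and while the loop index still points into the prompt, the freshly decoded token is overwritten by the prompt token wherever the prompt is not padding. A minimal standalone sketch of that arithmetic (tensor values are made up for illustration):

```python
import torch

pad_token_id = 3

# hypothetical batch of two prompts; the second row is padded with pad_token_id (values made up)
inputs = torch.tensor([[5, 8, 9],
                       [7, 3, 3]])

# trim as many trailing columns as the longest run of padding (same arithmetic as the hunk above)
trimmed = inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]
print(trimmed)  # tensor([[5], [7]]) -- only the columns that are padding-free in every row are kept up front

# inside the loop: while seq_idx still points into the prompt, keep the prompt token
# and let the sampled token fill only the positions that were padding
seq_idx = 1
last_token_id = torch.tensor([[11], [12]])  # freshly sampled token ids (made up)
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[:, seq_idx : seq_idx + 1] + pad_token_mask * last_token_id
print(last_token_id)  # tensor([[ 8], [12]]) -- prompt token kept in row 0, sampled token in row 1
```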
@@ -147,7 +156,7 @@ class RemoteGenerationMixin:
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
-        
+
         :param: input_ids: The input tokens to the model.
         :param: temperature: The temperature to use for sampling.
         :param: top_k: The number of samples to use for top_k sampling.
@@ -229,4 +238,3 @@ class RemoteGenerationMixin:
             constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
         constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
         return constraints
-
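For context, a hedged usage sketch of the generate API touched in this file. The keyword names follow the docstring and calls visible above (do_sample, temperature, max_new_tokens); the checkpoint name, tokenizer choice, and import path are assumptions, so treat this as illustrative rather than a verified snippet:

```python
from transformers import BloomTokenizerFast

# import path mirrors the file layout shown in this diff; checkpoint name is a placeholder
from src.client.remote_model import DistributedBloomForCausalLM

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-petals")  # placeholder checkpoint
model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals")

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(
    inputs,
    do_sample=True,       # multinomial sampling instead of greedy argmax
    temperature=0.8,
    max_new_tokens=16,
)
print(tokenizer.decode(outputs[0]))
```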

+ 1 - 0
src/client/remote_model.py

@@ -144,6 +144,7 @@ class DistributedBloomPrefix(DistributedBloomModel):

 class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
     config_class = DistributedBloomConfig

     def __init__(self, config: DistributedBloomConfig):
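Worth noting about the base-class order above: Python's MRO resolves names on BloomForCausalLM before RemoteGenerationMixin, so a method defined on both (such as generate) would come from BloomForCausalLM unless DistributedBloomForCausalLM overrides it itself. A minimal sketch with stand-in classes (the names below are illustrative, not the real ones):

```python
class BaseWithGenerate:   # stands in for BloomForCausalLM
    def generate(self):
        return "base generate"

class Mixin:              # stands in for RemoteGenerationMixin
    def generate(self):
        return "mixin generate"

class Combined(BaseWithGenerate, Mixin):
    pass

print(Combined.__mro__)       # Combined, BaseWithGenerate, Mixin, object
print(Combined().generate())  # "base generate" -- the first listed base wins
```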

+ 1 - 1
src/server/cache.py

@@ -25,7 +25,7 @@ class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
 
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
-        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
         self.device = device
         self.device = device
         self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
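The fallback value 2 ** 64 - 1 effectively disables the size limit. A rough, illustrative calculation of how one KV-cache allocation (the shape used in src/server/handler.py below) compares against such a limit; the concrete dimensions here are placeholders and this is not MemoryCache's actual accounting:

```python
import torch

max_size_bytes = 2 ** 64 - 1  # default when no limit is given: effectively unlimited

# bytes for one float32 cache tensor of shape (2, batch, max_length, num_heads, head_dim);
# the dimensions below are made up for illustration, not values from this repository
batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128
element_size = torch.tensor([], dtype=torch.float32).element_size()  # 4 bytes per float32
kv_bytes = 2 * batch_size * max_length * num_heads * head_dim * element_size

print(kv_bytes)                    # 67108864 (~64 MiB) for these placeholder dimensions
print(kv_bytes <= max_size_bytes)  # True -- the unlimited default never rejects an allocation
```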

+ 6 - 2
src/server/handler.py

@@ -39,7 +39,9 @@ class TransformerConnectionHandler(ConnectionHandler):

             batch_size = request.tensors[0].size[0] if request.tensors else 1

-            cache_metadata = torch.tensor([[-1, -1] for _ in range(batch_size)], dtype=torch.int64)  # [cache_handle, prefix_length]
+            cache_metadata = torch.tensor(
+                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+            )  # [cache_handle, prefix_length]
             prefix_length = 0

             async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
@@ -225,7 +227,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim

-                cache_descriptor = TensorDescriptor(size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+                cache_descriptor = TensorDescriptor(
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]

                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
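For orientation, the cache_metadata tensor reformatted above is just a small bookkeeping structure: one [cache_handle, prefix_length] pair per sequence, initialised to -1 sentinels. A brief sketch of how such a tensor looks and how a row might later be stamped (the handle value below is hypothetical, not taken from this code):

```python
import torch

batch_size = 2

# one [cache_handle, prefix_length] row per sequence, initialised to -1 sentinels
cache_metadata = torch.tensor([[-1, -1] for _ in range(batch_size)], dtype=torch.int64)
print(cache_metadata)  # tensor([[-1, -1], [-1, -1]])

# once a cache handle is known, every row can be stamped with it and the current prefix length
cache_handle, prefix_length = 42, 0  # hypothetical values
cache_metadata[:] = torch.tensor([cache_handle, prefix_length])
print(cache_metadata)  # tensor([[42,  0], [42,  0]])
```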

+ 3 - 1
src/utils/generation_algorithms.py

@@ -11,6 +11,7 @@ class DecodingAlgorithm(ABC):
     """
     """
     An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
     An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
     """
     """
+
     def __init__(self) -> None:
     def __init__(self) -> None:
         pass
         pass
 
 
@@ -26,6 +27,7 @@ class GreedyAlgorithm(DecodingAlgorithm):
     """
     """
     The simpliest algorithm for decoding. It selects the most probable token.
     The simpliest algorithm for decoding. It selects the most probable token.
     """
     """
+
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         """
         Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
         Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
@@ -38,7 +40,7 @@ class SamplingAlgorithm(DecodingAlgorithm):
         """
         """
         :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
         :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
         :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
         :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size). 
+        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
         """
         """
         logits[indices_to_remove] = -float("Inf")
         logits[indices_to_remove] = -float("Inf")
         probs = torch.softmax(logits / self.temperature, -1)
         probs = torch.softmax(logits / self.temperature, -1)
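The two lines above are the core of SamplingAlgorithm: banned positions are set to -Inf so they vanish after the softmax, and the temperature rescales the remaining logits. A self-contained sketch of that step; the final torch.multinomial draw is an assumption about what the rest of the method (not shown in this hunk) does:

```python
import torch

temperature = 0.8
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])                  # (batch_size * num_hypos, vocab_size)
indices_to_remove = torch.tensor([[False, False, True, True]])  # e.g. produced by a top-k rule

logits[indices_to_remove] = -float("Inf")       # banned tokens get zero probability after softmax
probs = torch.softmax(logits / temperature, -1)
print(probs)  # only the first two entries are non-zero

token_ids = torch.multinomial(probs, num_samples=1)  # one sampled token id per row
print(token_ids.shape)  # torch.Size([1, 1])
```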

+ 10 - 5
src/utils/generation_constraints.py

@@ -7,10 +7,11 @@ class ABCBloomConstraint(ABC):
     """
     """
     Base class of all kind of decoding constraints. It can be used to implement a new constraint.
     Base class of all kind of decoding constraints. It can be used to implement a new constraint.
     """
     """
+
     def __init__(self) -> None:
     def __init__(self) -> None:
         pass
         pass
 
 
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor,  hypo_ids: torch.Tensor) -> torch.Tensor:
+    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
         """
         """
         This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
         This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
         :param tokens_id: The token id of the last choosen token.
         :param tokens_id: The token id of the last choosen token.
@@ -31,7 +32,10 @@ class MaxNewTokensConstraint(ABCBloomConstraint):
         pad_token_id: The id of the padding token.
         min_logits: The minimum logits that can be generated. Default: -1e6.
     """
-    def __init__(self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
+
+    def __init__(
+        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
+    ) -> None:
         self.max_new_tokens = max_new_tokens
         self.current_generated_tokens = None
         self.eos_token_id = eos_token_id
@@ -44,7 +48,7 @@ class MaxNewTokensConstraint(ABCBloomConstraint):
         if tokens_id is not None:
             self.current_generated_tokens += 1

-        mask = (self.current_generated_tokens >= self.max_new_tokens)
+        mask = self.current_generated_tokens >= self.max_new_tokens
         logits += self.min_logits * mask
         logits[mask[:, 0], self.eos_token_id] = 0
         return logits
@@ -59,6 +63,7 @@ class EosConstraint(ABCBloomConstraint):
         pad_token_id: The id of the padding token.
         min_logits: The minimum logits that can be generated. Default: -1e6.
     """
+
     def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
         self.eos_token_id = eos_token_id
         self.min_logits = min_logits
@@ -68,10 +73,10 @@ class EosConstraint(ABCBloomConstraint):

     def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
         if self.past_tokens is not None:
-            mask = ((self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id))
+            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
             logits += self.min_logits * mask
             logits[mask[:, 0], self.eos_token_id] = 0
-        
+
         if tokens_id is not None:
             self.past_tokens = tokens_id
             self.wait_until_starting -= 1
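Since every constraint is just a callable applied to the logits once per decoding step (see the `lm_logits = constraint(last_token_id, lm_logits, hypo_ids)` line in remote_generation.py above), a new one can follow the same pattern as MaxNewTokensConstraint and EosConstraint. A hedged sketch; BannedTokenConstraint is hypothetical and not part of this commit:

```python
import torch

# import path mirrors the file layout shown in this diff
from src.utils.generation_constraints import ABCBloomConstraint

class BannedTokenConstraint(ABCBloomConstraint):
    """Pushes one token id's logit far down so it is (almost) never selected."""

    def __init__(self, banned_token_id: int, min_logits: float = -1e8) -> None:
        self.banned_token_id = banned_token_id
        self.min_logits = min_logits

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        logits[:, self.banned_token_id] += self.min_logits
        return logits

# such a constraint could then be handed to generate via provided_constraints,
# the keyword that appears in the _get_constraints call earlier in this diff
```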