Support Llama 2 (#379)

Alexander Borzunov, 2 years ago
commit 057a2fb5de

+ 1 - 1
src/petals/__init__.py

@@ -11,7 +11,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev3"
+__version__ = "1.2.0.dev4"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 14 - 9
src/petals/cli/run_server.py

@@ -25,6 +25,8 @@ def main():
                        help="path or name of a pretrained model, converted with cli/convert_model.py")
     group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
 
+    parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard")
+
     group = parser.add_mutually_exclusive_group(required=False)
     group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
     group.add_argument("--use_auth_token", action="store_true", dest="token",
@@ -59,16 +61,22 @@ def main():
 
     parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all operations (in total tokens)')
-    parser.add_argument('--max_batch_size', type=int, default=2048,
-                        help='The total number of tokens in the same batch will not exceed this value')
     parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
                         help='Pre-form this many subsequent batches while GPU is processing the current one')
     parser.add_argument('--sender_threads', type=int, default=1, required=False,
                         help='Use this many threads to pass results/exceptions from Runtime to Pools')
-    parser.add_argument('--inference_max_length', type=int, default=2048,
-                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
+
+    parser.add_argument('--inference_max_length', type=int, default=None,
+                        help='Maximum total sequence length permitted per inference. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all operations (in total tokens)')
+    parser.add_argument('--max_batch_size', type=int, default=None,
+                        help='The total number of tokens in the same batch will not exceed this value. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--attn_cache_tokens', type=int, default=None,
+                        help='The number of past attention key/value pairs that will be stored between inference steps. '
+                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
 
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@@ -86,9 +94,6 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_tokens', type=int, default=8192,
-                        help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
     parser.add_argument('--alloc_timeout', type=float, default=5,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
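
Note: the batch-size and cache flags above now default to None so that the server can pick model-dependent values (see the server.py changes below). A minimal sketch of that pattern, not taken from Petals itself:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_batch_size", type=int, default=None)
    args = parser.parse_args([])

    # None means "let the server decide"; the fallback mirrors the documented defaults.
    is_multiquery_attn = True  # assumption: a Llama-2-70b-style model
    if args.max_batch_size is None:
        args.max_batch_size = 8192 if is_multiquery_attn else 2048
    print(args.max_batch_size)  # 8192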

+ 3 - 1
src/petals/data_structures.py

@@ -27,12 +27,14 @@ class ServerInfo:
     state: ServerState
     throughput: RPS
 
+    public_name: Optional[str] = None
+    version: Optional[str] = None
+
     network_rps: Optional[RPS] = None
     forward_rps: Optional[RPS] = None
     inference_rps: Optional[RPS] = None
 
     adapters: Sequence[str] = ()
-    version: Optional[str] = None
     torch_dtype: Optional[str] = None
     quant_type: Optional[str] = None
     using_relay: Optional[bool] = None

+ 2 - 0
src/petals/models/bloom/config.py

@@ -18,6 +18,8 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM
     attn_class = BloomAttention
     block_prefix = "h"
 
+    num_key_value_groups = 1
+
     @classmethod
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs

+ 6 - 2
src/petals/models/llama/block.py

@@ -73,7 +73,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
     ) -> Tuple[torch.Tensor]:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
-        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
         value_states = value_states.view(*key_states.shape)
         return (key_states, value_states)
 
@@ -81,7 +83,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
         key_states, value_states = key_value
-        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
         key_states = key_states.view(*value_states.shape)
         key_states = key_states.permute(0, 2, 1)
         return (key_states, value_states)
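
Note: with grouped-query attention (as in Llama-2-70b), the cached key/value tensors carry num_key_value_heads heads rather than num_attention_heads, which is why the views above switch to self_attn.num_key_value_heads. A self-contained sketch of the same reshape round-trip, with assumed sizes (8 KV heads, head_dim 128):

    import torch

    batch_size, num_kv_heads, seq_length, head_dim = 2, 8, 16, 128
    key = torch.randn(batch_size * num_kv_heads, head_dim, seq_length)    # cache layout
    value = torch.randn(batch_size * num_kv_heads, seq_length, head_dim)  # cache layout

    # cache layout -> per-head layout (mirrors the first reshape above)
    key_std = key.permute(0, 2, 1).view(batch_size, num_kv_heads, seq_length, head_dim)
    value_std = value.view(batch_size, num_kv_heads, seq_length, head_dim)

    # per-head layout -> cache layout (mirrors the second reshape above) restores the originals
    assert torch.equal(value_std.view(batch_size * num_kv_heads, seq_length, head_dim), value)
    assert torch.equal(key_std.view(batch_size * num_kv_heads, seq_length, head_dim).permute(0, 2, 1), key)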

+ 11 - 3
src/petals/models/llama/config.py

@@ -18,13 +18,17 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
     attn_class = LlamaAttention
     block_prefix = "model.layers"
 
+    @property
+    def num_key_value_groups(self):
+        return self.num_attention_heads // self.num_key_value_heads
+
     @classmethod
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
         logger.info(
-            "LLaMA is available solely for non-commercial research purposes. "
-            "Make sure you follow the terms of use: https://bit.ly/llama-license"
+            "Make sure you follow the LLaMA's terms of use: "
+            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
         )
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
@@ -34,4 +38,8 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")
-        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
+        return result
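
Note: a quick worked example of the new property, using the published Llama 2 configs. Llama-2-70b uses grouped-query attention (64 attention heads, 8 key/value heads), while Llama-2-7b/13b and LLaMA 1 keep one KV head per attention head:

    num_attention_heads, num_key_value_heads = 64, 8    # Llama-2-70b
    assert num_attention_heads // num_key_value_heads == 8

    num_attention_heads, num_key_value_heads = 32, 32   # Llama-2-7b / LLaMA 1
    assert num_attention_heads // num_key_value_heads == 1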

+ 6 - 3
src/petals/server/backend.py

@@ -81,6 +81,7 @@ class TransformerBackend(ModuleBackend):
         head_dim = self.config.hidden_size // self.config.num_attention_heads
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            num_heads //= self.config.num_key_value_groups
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
             values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
             cache_tensors.extend((keys, values))
@@ -123,8 +124,10 @@ class TransformerBackend(ModuleBackend):
         """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
         key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
         for i in range(len(key_cache)):
-            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]
+            # shape: [batch * num_kv_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]
+            # shape: [batch * num_kv_heads, kv_length, head_dim]
         layer_past = tuple(chain(*zip(key_cache, value_cache)))
         return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
 
@@ -132,7 +135,7 @@ class TransformerBackend(ModuleBackend):
         self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
     ):
         """Writes new key/value tensors back into cache, works in-place"""
-        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
         for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
             new_key = new_key.view(*cache_key.shape[:3], new_length)
             cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
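
Note: dividing num_heads by num_key_value_groups means the attention cache only stores the key/value heads, not all query heads. A rough sketch of the resulting descriptor shapes under assumed values (a single shard holding a Llama-2-70b-style block):

    batch_size, max_length = 1, 8192
    hidden_size, num_attention_heads, num_key_value_groups = 8192, 64, 8
    head_dim = hidden_size // num_attention_heads              # 128

    num_heads = num_attention_heads // num_key_value_groups    # only 8 KV heads are cached
    keys_shape = (batch_size, num_heads, head_dim, max_length)     # (1, 8, 128, 8192)
    values_shape = (batch_size, num_heads, max_length, head_dim)   # (1, 8, 8192, 128)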

+ 4 - 0
src/petals/server/from_pretrained.py

@@ -23,6 +23,7 @@ from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.hf_auth import always_needs_auth
 
 logger = get_logger(__name__)
 
@@ -86,6 +87,9 @@ def _load_state_dict_from_repo(
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
+    if always_needs_auth(model_name) and token is None:
+        token = True
+
     index_file = get_file_from_repo(
         model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
     )

+ 1 - 2
src/petals/server/reachability.py

@@ -145,8 +145,7 @@ class ReachabilityProtocol(ServicerBase):
                 async with protocol.serve(common_p2p):
                     await protocol._stop.wait()
             except Exception as e:
-                logger.warning(f"Reachability service failed: {repr(e)}")
-                logger.debug("See detailed traceback below:", exc_info=True)
+                logger.debug("Reachability service failed:", exc_info=True)
 
                 if not ready.done():
                     ready.set_exception(e)

+ 20 - 6
src/petals/server/server.py

@@ -50,18 +50,19 @@ class Server:
         initial_peers: List[str],
         dht_prefix: Optional[str],
         converted_model_name_or_path: str,
+        public_name: Optional[str] = None,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: int = 8,
+        inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
-        max_batch_size: int = 2048,
-        inference_max_length: int = 2048,
+        max_batch_size: Optional[int] = None,
+        attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
@@ -93,8 +94,6 @@ class Server:
         self.converted_model_name_or_path = converted_model_name_or_path
 
         self.num_handlers = num_handlers
-        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
-        self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@@ -177,8 +176,19 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        is_multiquery_attn = self.block_config.num_key_value_groups > 1
+        if max_batch_size is None:
+            max_batch_size = 8192 if is_multiquery_attn else 2048
+        if inference_max_length is None:
+            inference_max_length = 8192 if is_multiquery_attn else 2048
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+
         # For attention cache in GPU or RAM
+        if attn_cache_tokens is None:
+            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
         # For disk cache
@@ -222,8 +232,9 @@ class Server:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            adapters=tuple(adapters),
+            public_name=public_name,
             version=petals.__version__,
+            adapters=tuple(adapters),
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
@@ -642,7 +653,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.dht = dht
         self.server_info = server_info
         self.memory_cache = memory_cache
+
         self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token //= block_config.num_key_value_groups
+
         self.update_period = update_period
         self.expiration = expiration
         self.trigger = threading.Event()
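
Note: a back-of-the-envelope check of the new per-block cache budget, with assumed Llama-2-70b-like numbers (hidden_size 8192, float16, the 32768-token default above, 8 KV groups):

    hidden_size = 8192
    attn_cache_tokens = 32768
    num_key_value_groups = 8       # 64 attention heads / 8 KV heads
    bits = 16                      # torch.float16

    cache_values_per_block = 2 * hidden_size * attn_cache_tokens // num_key_value_groups
    cache_bytes_per_block = cache_values_per_block * bits // 8
    print(cache_bytes_per_block / 2**20)  # 128.0 MiB per block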

+ 11 - 4
src/petals/utils/auto_config.py

@@ -1,8 +1,12 @@
+import os
+import re
 from dataclasses import dataclass
-from typing import Optional, Type
+from typing import Optional, Type, Union
 
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
+from petals.utils.hf_auth import always_needs_auth
+
 
 @dataclass
 class _ModelClasses:
@@ -26,8 +30,11 @@ class _AutoDistributedBase:
     _mapping_field = None  # Should be defined in child classes
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
-        config = AutoConfig.from_pretrained(*args, **kwargs)
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
+        if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
+            kwargs["token"] = True
+
+        config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
         if config.model_type not in _CLASS_MAPPING:
             raise ValueError(f"Petals does not support model type {config.model_type}")
 
@@ -35,7 +42,7 @@ class _AutoDistributedBase:
         if proper_cls is None:
             raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
 
-        return proper_cls.from_pretrained(*args, **kwargs)
+        return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
 
 
 class AutoDistributedConfig(_AutoDistributedBase):
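
Note: with this change, loading a gated Llama 2 repo without an explicit token falls back to token=True, i.e., the token saved by huggingface-cli login. A hedged usage sketch (requires petals installed and access to the gated repo):

    from petals.utils.auto_config import AutoDistributedConfig

    # Equivalent to passing token=True explicitly, since always_needs_auth() matches this repo name.
    config = AutoDistributedConfig.from_pretrained("meta-llama/Llama-2-70b-hf")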

+ 7 - 2
src/petals/utils/convert_block.py

@@ -2,6 +2,7 @@
 Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import re
+from enum import Enum
 from typing import Optional, Sequence
 
 import tensor_parallel as tp
@@ -11,12 +12,16 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
-from petals.utils.misc import QuantType
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
 def convert_block(
     block: nn.Module,
     block_index: int,
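
Note: QuantType moves from petals.utils.misc into petals.utils.convert_block (the peft.py hunk below updates its import accordingly), so any code still importing it from misc needs the new path:

    from petals.utils.convert_block import QuantType

    assert QuantType.NF4.value == 2  # 4-bit quantization as in the QLoRA paper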

+ 7 - 0
src/petals/utils/hf_auth.py

@@ -0,0 +1,7 @@
+import os
+from typing import Union
+
+
+def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
+    loading_from_repo = model_name is not None and not os.path.isdir(model_name)
+    return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
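
Note: a quick sanity check of the new helper, assuming the names below are not local directories:

    from petals.utils.hf_auth import always_needs_auth

    assert always_needs_auth("meta-llama/Llama-2-70b-chat-hf")  # gated Llama 2 repo
    assert not always_needs_auth("bigscience/bloom-560m")       # public repo
    assert not always_needs_auth(None)                          # no model name given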

+ 0 - 9
src/petals/utils/misc.py

@@ -1,14 +1,5 @@
-from enum import Enum
-
 import torch
 
-
-class QuantType(Enum):
-    NONE = 0
-    INT8 = 1  # 8-bit as in the LLM.int8() paper
-    NF4 = 2  # 4-bit as in the QLoRA paper
-
-
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
 
 

+ 1 - 1
src/petals/utils/peft.py

@@ -17,8 +17,8 @@ from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
 from petals.server.block_utils import resolve_block_dtype
+from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
-from petals.utils.misc import QuantType
 
 logger = get_logger(__name__)