@@ -21,6 +21,7 @@ from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import get_file_from_repo
+from transformers.quantizers import AutoHfQuantizer
 
 from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
@@ -28,6 +29,7 @@ from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth
+from petals.utils.convert_block import is_gptq_quant
 
 logger = get_logger(__name__)
 
@@ -55,6 +57,23 @@ def load_pretrained_block(
     block = get_model_block(config, layer_idx=block_index)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
+
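+    # For GPTQ checkpoints, let the HF quantizer swap the block's Linear layers
+    # for their quantized counterparts before the state dict is loaded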
+    if is_gptq_quant(config):
+        hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
+        hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
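+        # preprocess_model() checks model.__class__.main_input_name, so wrap the
+        # lone block in a throwaway ModuleList and patch that attribute on it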
+        tmp_block_list = torch.nn.ModuleList([block])
+        tmp_block_list.__class__.main_input_name = "input_ids"
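+        # Let the quantizer adjust the dtype and device placement GPTQ expects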
+        torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+        device_map = hf_quantizer.update_device_map("cuda")
+        hf_quantizer.preprocess_model(
+            model=tmp_block_list, device_map=device_map, keep_in_fp32_modules=False,
+        )
+
     state_dict = _load_state_dict_from_repo(
         model_name,
         block_prefix,