
Add AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification} (#329)

This PR adds `petals.AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification}` classes, similar to their `transformers.Auto{Model, ModelForCausalLM, ModelForSequenceClassification}` counterparts.
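A minimal usage sketch of the new classes (the checkpoint name below is illustrative, not part of this PR):

```python
from petals import AutoDistributedModelForCausalLM

# AutoConfig resolves the checkpoint's model_type; the registry populated by
# petals.models.* then dispatches to the matching Distributed*ForCausalLM class.
model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom-560m")
```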
Alexander Borzunov, 2 years ago
Parent commit: 7a37513f77

+ 8 - 0
src/petals/models/bloom/__init__.py

@@ -5,3 +5,11 @@ from petals.models.bloom.model import (
     DistributedBloomForSequenceClassification,
     DistributedBloomModel,
 )
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedBloomConfig,
+    model=DistributedBloomModel,
+    model_for_causal_lm=DistributedBloomForCausalLM,
+    model_for_sequence_classification=DistributedBloomForSequenceClassification,
+)

+ 0 - 3
src/petals/models/bloom/config.py

@@ -30,6 +30,3 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM
             dht_prefix = str(model_name_or_path) + "-petals"
             logger.info(f"Using DHT prefix: {dht_prefix}")
         return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
-
-
-AutoDistributedConfig.register(DistributedBloomConfig)

+ 8 - 0
src/petals/models/llama/__init__.py

@@ -5,3 +5,11 @@ from petals.models.llama.model import (
     DistributedLlamaForSequenceClassification,
     DistributedLlamaModel,
 )
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedLlamaConfig,
+    model=DistributedLlamaModel,
+    model_for_causal_lm=DistributedLlamaForCausalLM,
+    model_for_sequence_classification=DistributedLlamaForSequenceClassification,
+)

+ 0 - 3
src/petals/models/llama/config.py

@@ -30,6 +30,3 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
                 dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
             logger.info(f"Using DHT prefix: {dht_prefix}")
         return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
-
-
-AutoDistributedConfig.register(DistributedLlamaConfig)

+ 6 - 1
src/petals/utils/__init__.py

@@ -1 +1,6 @@
-from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.auto_config import (
+    AutoDistributedConfig,
+    AutoDistributedModel,
+    AutoDistributedModelForCausalLM,
+    AutoDistributedModelForSequenceClassification,
+)

+ 43 - 12
src/petals/utils/auto_config.py

@@ -1,23 +1,54 @@
-from typing import Type
+from dataclasses import dataclass
+from typing import Optional, Type
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
-CONFIG_MAPPING = {}  # Populated with AutoDistributedConfig.register()
 
+@dataclass
+class _ModelClasses:
+    config: Type[PretrainedConfig]
+    model: Optional[Type[PreTrainedModel]] = None
+    model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
+    model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
+
+
+_CLASS_MAPPING = {}  # Populated by petals.models.* subpackages with register_model_classes()
+
+
+def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
+    assert issubclass(config, PretrainedConfig)
+    assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
+
+    _CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
+
+
+class _AutoDistributedBase:
+    _mapping_field = None  # Should be defined in child classes
 
-class AutoDistributedConfig:
     @classmethod
     def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
         config = AutoConfig.from_pretrained(*args, **kwargs)
-        if config.model_type not in CONFIG_MAPPING:
+        if config.model_type not in _CLASS_MAPPING:
             raise ValueError(f"Petals does not support model type {config.model_type}")
 
-        dist_config_class = CONFIG_MAPPING[config.model_type]
-        return dist_config_class.from_pretrained(*args, **kwargs)
+        proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
+        if proper_cls is None:
+            raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
+
+        return proper_cls.from_pretrained(*args, **kwargs)
+
+
+class AutoDistributedConfig(_AutoDistributedBase):
+    _mapping_field = "config"
+
+
+class AutoDistributedModel(_AutoDistributedBase):
+    _mapping_field = "model"
+
+
+class AutoDistributedModelForCausalLM(_AutoDistributedBase):
+    _mapping_field = "model_for_causal_lm"
 
-    @staticmethod
-    def register(config_class: Type[PretrainedConfig]) -> None:
-        assert issubclass(config_class, PretrainedConfig)
-        assert config_class.model_type not in CONFIG_MAPPING
 
-        CONFIG_MAPPING[config_class.model_type] = config_class
+class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
+    _mapping_field = "model_for_sequence_classification"