|
@@ -1,23 +1,54 @@
|
|
|
-from typing import Type
|
|
|
+from dataclasses import dataclass
|
|
|
+from typing import Optional, Type
|
|
|
|
|
|
-from transformers import AutoConfig, PretrainedConfig
|
|
|
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
|
|
|
|
|
-CONFIG_MAPPING = {} # Populated with AutoDistributedConfig.register()
|
|
|
|
|
|
+@dataclass
|
|
|
+class _ModelClasses:
|
|
|
+ config: Type[PretrainedConfig]
|
|
|
+ model: Optional[Type[PreTrainedModel]] = None
|
|
|
+ model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
|
|
|
+ model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
|
|
|
+
|
|
|
+
|
|
|
+_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes()
|
|
|
+
|
|
|
+
|
|
|
+def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
|
|
|
+ assert issubclass(config, PretrainedConfig)
|
|
|
+ assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
|
|
|
+
|
|
|
+ _CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
+class _AutoDistributedBase:
|
|
|
+ _mapping_field = None # Should be defined in child classes
|
|
|
|
|
|
-class AutoDistributedConfig:
|
|
|
@classmethod
|
|
|
def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
|
|
|
config = AutoConfig.from_pretrained(*args, **kwargs)
|
|
|
- if config.model_type not in CONFIG_MAPPING:
|
|
|
+ if config.model_type not in _CLASS_MAPPING:
|
|
|
raise ValueError(f"Petals does not support model type {config.model_type}")
|
|
|
|
|
|
- dist_config_class = CONFIG_MAPPING[config.model_type]
|
|
|
- return dist_config_class.from_pretrained(*args, **kwargs)
|
|
|
+ proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
|
|
|
+ if proper_cls is None:
|
|
|
+ raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
|
|
|
+
|
|
|
+ return proper_cls.from_pretrained(*args, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
+class AutoDistributedConfig(_AutoDistributedBase):
|
|
|
+ _mapping_field = "config"
|
|
|
+
|
|
|
+
|
|
|
+class AutoDistributedModel(_AutoDistributedBase):
|
|
|
+ _mapping_field = "model"
|
|
|
+
|
|
|
+
|
|
|
+class AutoDistributedModelForCausalLM(_AutoDistributedBase):
|
|
|
+ _mapping_field = "model_for_causal_lm"
|
|
|
|
|
|
- @staticmethod
|
|
|
- def register(config_class: Type[PretrainedConfig]) -> None:
|
|
|
- assert issubclass(config_class, PretrainedConfig)
|
|
|
- assert config_class.model_type not in CONFIG_MAPPING
|
|
|
|
|
|
- CONFIG_MAPPING[config_class.model_type] = config_class
|
|
|
+class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
|
|
|
+ _mapping_field = "model_for_sequence_classification"
|