|
import copy |
|
import os |
|
|
|
from transformers import AutoConfig, AutoModelForCTC, PretrainedConfig |
|
from transformers.dynamic_module_utils import ( |
|
get_class_from_dynamic_module, |
|
resolve_trust_remote_code, |
|
) |
|
from transformers.models.auto.auto_factory import _get_model_class |
|
|
|
from .extractors import Conv2dFeatureExtractor |
|
|
|
|
|
class FeatureExtractionInitModifier(type):
    """Metaclass that patches a class's ``__init__`` to optionally swap in a 2-D feature extractor.

    Calling ``FeatureExtractionInitModifier(name, bases, dct)`` creates a new class whose ``__init__``
    first runs the ``__init__`` the class would otherwise have (its own or an inherited one) and then,
    when ``self.config.expect_2d_input`` is truthy, replaces
    ``getattr(self, self.base_model_prefix).feature_extractor`` with
    ``Conv2dFeatureExtractor(self.config)``.

    Configs that do not define ``expect_2d_input`` are treated as ``expect_2d_input=False`` and keep
    the feature extractor built by the original ``__init__``.
    """

    def __new__(cls, name, bases, dct):
        from functools import wraps

        new_cls = super().__new__(cls, name, bases, dct)

        # The constructor the new class would use without patching (defined in ``dct`` or inherited).
        original_init = new_cls.__init__

        # ``wraps`` keeps the original name/docstring and sets ``__wrapped__``, so introspection such as
        # ``inspect.signature`` reports the real constructor instead of ``(*args, **kwargs)``.
        @wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # getattr with a default: configs that never declare the flag keep the default extractor
            # instead of failing with AttributeError.
            if getattr(self.config, "expect_2d_input", False):
                backbone = getattr(self, self.base_model_prefix)
                backbone.feature_extractor = Conv2dFeatureExtractor(self.config)

        new_cls.__init__ = new_init
        return new_cls
|
|
|
|
|
class CustomAutoModelForCTC(AutoModelForCTC):
    """``AutoModelForCTC`` variant whose resolved model classes are wrapped by
    ``FeatureExtractionInitModifier``.

    Class resolution follows the ``AutoModelForCTC`` flow: remote code referenced by
    ``config.auto_map`` (when trusted), otherwise the local model mapping. The concrete class is
    subclassed through ``FeatureExtractionInitModifier`` before instantiation. That patched
    ``__init__`` replaces the base model's ``feature_extractor`` with ``Conv2dFeatureExtractor``
    when ``config.expect_2d_input`` is set.
    """

    # Name of the upstream auto class. ``register_for_auto_class`` only accepts auto-class names that
    # exist in transformers (this subclass's name does not). Remote-code checkpoints normally list
    # their model class in ``config.auto_map`` under this key.
    _upstream_auto_class_name = AutoModelForCTC.__name__

    @classmethod
    def _remote_code_key(cls, config):
        """Return the ``config.auto_map`` key pointing at remote code for this class, or ``None``.

        ``cls.__name__`` is tried first (the historical lookup). The upstream ``AutoModelForCTC`` key
        is the fallback.
        """
        auto_map = getattr(config, "auto_map", None) or {}
        for key in (cls.__name__, cls._upstream_auto_class_name):
            if key in auto_map:
                return key
        return None

    @staticmethod
    def _wrap_model_class(model_class):
        """Subclass ``model_class`` through ``FeatureExtractionInitModifier``, keeping its name."""
        return FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})

    @classmethod
    def _register_remote_class(cls, model_class, config, name_or_path):
        """Register a dynamically loaded (unwrapped) model class.

        A local directory registers the class for the upstream auto class. Anything else registers
        it in this auto class's model mapping. ``_model_mapping`` is not redefined here, so it is the
        one inherited from ``AutoModelForCTC``.
        """
        if os.path.isdir(name_or_path):
            model_class.register_for_auto_class(cls._upstream_auto_class_name)
        else:
            cls.register(config.__class__, model_class, exist_ok=True)

    @classmethod
    def _unrecognized_config_error(cls, config):
        """Build the error raised when no model class can be resolved for ``config``."""
        return ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a pretrained CTC model whose class is wrapped by ``FeatureExtractionInitModifier``.

        Args:
            pretrained_model_name_or_path: Hub id or local path of the checkpoint.
            *model_args: Forwarded to the concrete model's ``from_pretrained``.
            **kwargs: ``config`` (a ``PretrainedConfig``; loaded via ``AutoConfig`` when absent or
                not a config instance), ``trust_remote_code``, hub options (see
                ``hub_kwargs_names``) and any model kwargs, all forwarded as in ``AutoModelForCTC``.

        Returns:
            The instantiated model, an instance of the wrapped subclass.

        Raises:
            ValueError: If the config resolves to neither trusted remote code nor a known local class.
        """
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "code_revision",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}

        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # Keep torch_dtype="auto" out of config loading; it is restored for the model below.
            if kwargs.get("torch_dtype", None) == "auto":
                _ = kwargs.pop("torch_dtype")
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                trust_remote_code=trust_remote_code,
                **hub_kwargs,
                **kwargs,
            )
            if kwargs_orig.get("torch_dtype", None) == "auto":
                kwargs["torch_dtype"] = "auto"

        remote_key = cls._remote_code_key(config)
        has_remote_code = remote_key is not None
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[remote_key]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
            )
            _ = hub_kwargs.pop("code_revision", None)
            # Register the raw class before wrapping (as from_config does), so the wrapper subclass is
            # never stored in the mapping shared with plain AutoModelForCTC.
            cls._register_remote_class(model_class, config, pretrained_model_name_or_path)
            model_class = cls._wrap_model_class(model_class)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )

        if has_local_code:
            model_class = cls._wrap_model_class(_get_model_class(config, cls._model_mapping))
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )

        raise cls._unrecognized_config_error(config)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Instantiate a freshly initialised wrapped CTC model from ``config`` (no weights loaded).

        Args:
            config: The model configuration; ``config._name_or_path`` is used to locate remote code.
            **kwargs: ``trust_remote_code`` plus kwargs forwarded to the dynamic-module loader and
                to the concrete model's ``_from_config``.

        Returns:
            The instantiated model, an instance of the wrapped subclass.

        Raises:
            ValueError: If the config resolves to neither trusted remote code nor a known local class.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        remote_key = cls._remote_code_key(config)
        has_remote_code = remote_key is not None
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[remote_key]
            # "repo--module.Class" names the repository explicitly; otherwise use the config's path.
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            cls._register_remote_class(model_class, config, config._name_or_path)
            _ = kwargs.pop("code_revision", None)
            return cls._wrap_model_class(model_class)._from_config(config, **kwargs)

        if has_local_code:
            model_class = cls._wrap_model_class(_get_model_class(config, cls._model_mapping))
            return model_class._from_config(config, **kwargs)

        raise cls._unrecognized_config_error(config)
|
|