|
from transformers import PretrainedConfig, AutoConfig |
|
|
|
|
|
class CLIPEncoderDecoderConfig(PretrainedConfig):
    """Composite configuration holding an encoder config and a decoder config.

    The two sub-configurations are exposed as ``self.encoder`` and
    ``self.decoder``. Either one may be supplied through the ``encoder`` /
    ``decoder`` keyword arguments, as a ``PretrainedConfig`` instance or as
    a plain dict containing a ``"model_type"`` key (the shape produced by
    ``PretrainedConfig.to_dict()``). When a sub-configuration is not
    supplied, a default checkpoint configuration is loaded with
    ``AutoConfig.from_pretrained``.
    """

    model_type = "clip-encoder-decoder"

    def __init__(self, **kwargs):
        """Build the composite configuration.

        Args:
            **kwargs: Optional ``encoder`` and ``decoder`` entries (config
                instance or dict, see the class docstring). Every other
                keyword is forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If a sub-configuration dict has no ``"model_type"``.
        """
        # Pull the sub-configs out first so the base class does not store
        # the raw values as attributes.
        encoder = kwargs.pop("encoder", None)
        decoder = kwargs.pop("decoder", None)

        super().__init__(**kwargs)

        if encoder is None:
            self.encoder = AutoConfig.from_pretrained('facebook/convnextv2-base-22k-224')
            # The default encoder config is given an explicit hidden_size.
            # NOTE(review): presumably read by the model that consumes this
            # config -- confirm against the modelling code.
            self.encoder.hidden_size = 1024
        else:
            self.encoder = self._as_config(encoder)

        if decoder is None:
            self.decoder = AutoConfig.from_pretrained('airesearch/wangchanberta-base-att-spm-uncased')
        else:
            self.decoder = self._as_config(decoder)

        self.is_encoder_decoder = True

    @staticmethod
    def _as_config(sub_config):
        """Return ``sub_config`` as a ``PretrainedConfig`` instance.

        Instances are returned unchanged. Dicts are rebuilt with
        ``AutoConfig.for_model`` using their ``"model_type"`` entry.
        """
        if isinstance(sub_config, PretrainedConfig):
            return sub_config

        # Work on a copy so the caller's dict is not mutated by the pop.
        params = dict(sub_config)
        sub_model_type = params.pop("model_type", None)
        if sub_model_type is None:
            raise ValueError(
                "A sub-configuration dict must contain a 'model_type' key."
            )
        return AutoConfig.for_model(sub_model_type, **params)

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`CLIPEncoderDecoderConfig`] (or a derived class) from an encoder model configuration and a
        decoder model configuration.

        `decoder_config` is modified in place: `is_decoder` and `add_cross_attention` are set to `True` before it
        is serialized.

        Args:
            encoder_config: Configuration of the encoder.
            decoder_config: Configuration of the decoder.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            [`CLIPEncoderDecoderConfig`]: An instance of a configuration object
        """
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
|
|