from typing import List, Optional

from transformers import PretrainedConfig, VisionEncoderDecoderConfig


class MagiConfig(PretrainedConfig):
    """Configuration for the composite Magi model.

    Bundles the serialized configurations of three sub-models — detection,
    OCR, and crop embedding — together with their image-preprocessing
    settings into a single ``PretrainedConfig`` subclass, so the composite
    model can round-trip through the standard ``transformers``
    save/load machinery.
    """

    model_type = "magi"

    def __init__(
        self,
        disable_ocr: bool = False,
        disable_crop_embeddings: bool = False,
        disable_detections: bool = False,
        detection_model_config: Optional[dict] = None,
        ocr_model_config: Optional[dict] = None,
        crop_embedding_model_config: Optional[dict] = None,
        detection_image_preprocessing_config: Optional[dict] = None,
        ocr_pretrained_processor_path: Optional[str] = None,
        crop_embedding_image_preprocessing_config: Optional[dict] = None,
        **kwargs,
    ):
        """
        Args:
            disable_ocr: Stored as-is; presumably toggles the OCR
                sub-model off in the model code — confirm against caller.
            disable_crop_embeddings: Stored as-is; presumably toggles the
                crop-embedding sub-model off.
            disable_detections: Stored as-is; presumably toggles the
                detection sub-model off.
            detection_model_config: Serialized (dict) config for the
                detection sub-model; rehydrated via
                ``PretrainedConfig.from_dict`` when given.
            ocr_model_config: Serialized config for the OCR sub-model;
                rehydrated via ``VisionEncoderDecoderConfig.from_dict``
                when given.
            crop_embedding_model_config: Serialized config for the
                crop-embedding sub-model; rehydrated via
                ``PretrainedConfig.from_dict`` when given.
            detection_image_preprocessing_config: Preprocessing settings
                for detection inputs, stored unchanged.
            ocr_pretrained_processor_path: Identifier/path of the
                pretrained OCR processor, stored unchanged.
            crop_embedding_image_preprocessing_config: Preprocessing
                settings for crop-embedding inputs, stored unchanged.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        self.disable_ocr = disable_ocr
        self.disable_crop_embeddings = disable_crop_embeddings
        self.disable_detections = disable_detections
        # Rehydrate sub-model configs from their serialized dict form;
        # each stays None when not provided (e.g. sub-model disabled or
        # config constructed fresh).
        self.detection_model_config = (
            PretrainedConfig.from_dict(detection_model_config)
            if detection_model_config is not None
            else None
        )
        self.ocr_model_config = (
            VisionEncoderDecoderConfig.from_dict(ocr_model_config)
            if ocr_model_config is not None
            else None
        )
        self.crop_embedding_model_config = (
            PretrainedConfig.from_dict(crop_embedding_model_config)
            if crop_embedding_model_config is not None
            else None
        )
        # Preprocessing settings and processor path are kept in raw form.
        self.detection_image_preprocessing_config = detection_image_preprocessing_config
        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
        # Base-class init last, per the transformers config convention,
        # so it can consume any remaining standard kwargs.
        super().__init__(**kwargs)