fix clip_vision trouble
Browse files
src/configuration_medclip.py
CHANGED
@@ -2,7 +2,7 @@ import copy
|
|
2 |
|
3 |
from transformers.configuration_utils import PretrainedConfig
|
4 |
from transformers.utils import logging
|
5 |
-
|
6 |
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
@@ -69,12 +69,12 @@ class MedCLIPConfig(PretrainedConfig):
|
|
69 |
text_model_type = text_config.pop("model_type")
|
70 |
vision_model_type = vision_config.pop("model_type")
|
71 |
|
72 |
-
from transformers import AutoConfig
|
73 |
-
|
74 |
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
75 |
|
76 |
if vision_model_type == "clip":
|
77 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
|
|
|
|
78 |
else:
|
79 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
80 |
|
|
|
2 |
|
3 |
from transformers.configuration_utils import PretrainedConfig
|
4 |
from transformers.utils import logging
|
5 |
+
from transformers import AutoConfig, CLIPVisionConfig
|
6 |
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
|
|
69 |
text_model_type = text_config.pop("model_type")
|
70 |
vision_model_type = vision_config.pop("model_type")
|
71 |
|
|
|
|
|
72 |
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
73 |
|
74 |
if vision_model_type == "clip":
|
75 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
76 |
+
elif vision_model_type == "clip_vision_model":
|
77 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
78 |
else:
|
79 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
80 |
|