shpotes commited on
Commit
5d9f9f0
1 Parent(s): 84d39fc

fix clip_vision trouble

Browse files
Files changed (1) hide show
  1. src/configuration_medclip.py +3 -3
src/configuration_medclip.py CHANGED
@@ -2,7 +2,7 @@ import copy
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
  from transformers.utils import logging
5
-
6
 
7
  logger = logging.get_logger(__name__)
8
 
@@ -69,12 +69,12 @@ class MedCLIPConfig(PretrainedConfig):
69
  text_model_type = text_config.pop("model_type")
70
  vision_model_type = vision_config.pop("model_type")
71
 
72
- from transformers import AutoConfig
73
-
74
  self.text_config = AutoConfig.for_model(text_model_type, **text_config)
75
 
76
  if vision_model_type == "clip":
77
  self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
 
 
78
  else:
79
  self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
80
 
 
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
  from transformers.utils import logging
5
+ from transformers import AutoConfig, CLIPVisionConfig
6
 
7
  logger = logging.get_logger(__name__)
8
 
 
69
  text_model_type = text_config.pop("model_type")
70
  vision_model_type = vision_config.pop("model_type")
71
 
 
 
72
  self.text_config = AutoConfig.for_model(text_model_type, **text_config)
73
 
74
  if vision_model_type == "clip":
75
  self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
76
+ elif vision_model_type == "clip_vision_model":
77
+ self.vision_config = CLIPVisionConfig(**vision_config)
78
  else:
79
  self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
80