michael-guenther commited on
Commit
ab448a5
1 Parent(s): d7c984c

change use_flash_attn and add x_attention attribute

Browse files
Files changed (1) hide show
  1. configuration_clip.py +8 -4
configuration_clip.py CHANGED
@@ -155,7 +155,8 @@ class JinaCLIPConfig(PretrainedConfig):
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
158
- use_flash_attn: bool = False,
 
159
  **kwargs,
160
  ):
161
  # If `_config_dict` exist, we use them for the backward compatibility.
@@ -164,7 +165,8 @@ class JinaCLIPConfig(PretrainedConfig):
164
 
165
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
166
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
167
- self.use_flash_attn = use_flash_attn
 
168
 
169
  super().__init__(**kwargs)
170
 
@@ -261,8 +263,10 @@ class JinaCLIPConfig(PretrainedConfig):
261
  'with default values.'
262
  )
263
 
264
- if use_flash_attn:
265
- text_config.hf_model_config_kwargs.use_flash_attn = use_flash_attn
 
 
266
 
267
  self.text_config = JinaCLIPTextConfig(**text_config)
268
  self.vision_config = JinaCLIPVisionConfig(**vision_config)
 
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
158
+ use_text_flash_attn: Optional[bool] = None,
159
+ use_vision_xformers: Optional[bool] = None,
160
  **kwargs,
161
  ):
162
  # If `_config_dict` exist, we use them for the backward compatibility.
 
165
 
166
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
167
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
168
+ self.use_text_flash_attn = use_text_flash_attn
169
+ self.use_vision_xformers = use_vision_xformers
170
 
171
  super().__init__(**kwargs)
172
 
 
263
  'with default values.'
264
  )
265
 
266
+ if use_text_flash_attn:
267
+ text_config.hf_model_config_kwargs.use_flash_attn = use_text_flash_attn
268
+ if use_vision_xformers:
269
+ vision_config.x_attention = use_vision_xformers
270
 
271
  self.text_config = JinaCLIPTextConfig(**text_config)
272
  self.vision_config = JinaCLIPVisionConfig(**vision_config)