michael-guenther committed on
Commit
d7c984c
1 Parent(s): ab3363e

support flash attn in from_pretrained

Browse files
Files changed (1) hide show
  1. configuration_clip.py +5 -0
configuration_clip.py CHANGED
@@ -155,6 +155,7 @@ class JinaCLIPConfig(PretrainedConfig):
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
 
158
  **kwargs,
159
  ):
160
  # If `_config_dict` exist, we use them for the backward compatibility.
@@ -163,6 +164,7 @@ class JinaCLIPConfig(PretrainedConfig):
163
 
164
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
165
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
 
166
 
167
  super().__init__(**kwargs)
168
 
@@ -259,6 +261,9 @@ class JinaCLIPConfig(PretrainedConfig):
259
  'with default values.'
260
  )
261
 
 
 
 
262
  self.text_config = JinaCLIPTextConfig(**text_config)
263
  self.vision_config = JinaCLIPVisionConfig(**vision_config)
264
 
 
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
158
+ use_flash_attn: bool = False,
159
  **kwargs,
160
  ):
161
  # If `_config_dict` exist, we use them for the backward compatibility.
 
164
 
165
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
166
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
167
+ self.use_flash_attn = use_flash_attn
168
 
169
  super().__init__(**kwargs)
170
 
 
261
  'with default values.'
262
  )
263
 
264
+ if use_flash_attn:
265
+ text_config.hf_model_config_kwargs.use_flash_attn = use_flash_attn
266
+
267
  self.text_config = JinaCLIPTextConfig(**text_config)
268
  self.vision_config = JinaCLIPVisionConfig(**vision_config)
269