gmastrapas committed
Commit 4f6f082
1 Parent(s): ed1da94

fix: throw warnings if xformers or flash-attn can't be used
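The checks pair `torch.cuda.is_available()` with `importlib.util.find_spec`, which returns `None` when a package is not importable, so an unusable attention backend is disabled with a warning instead of raising at load time. A minimal sketch of the guard; the `_is_available` helper is an illustrative name, not part of the commit:

```python
import importlib.util
import warnings

import torch


def _is_available(package: str) -> bool:
    # find_spec returns a ModuleSpec if the package is importable, else None
    return importlib.util.find_spec(package) is not None


# Mirror of the commit's guard logic for the text tower:
if not torch.cuda.is_available():
    warnings.warn('Flash attention requires CUDA, disabling')
elif not _is_available('flash_attn'):
    warnings.warn('Flash attention is not installed, disabling')
```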

Files changed (1)
  1. modeling_clip.py +58 -7
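The new `_resolve_attention_libs` helper below gives an explicit top-level setting (`config.use_text_flash_attn`, `config.use_vision_xformers`) precedence over the per-tower defaults (`text_config.hf_model_config_kwargs['use_flash_attn']`, defaulting to `True`, and `vision_config.x_attention`). A minimal sketch of that precedence rule; the `resolve` helper is illustrative, not part of the commit:

```python
from typing import Optional


def resolve(override: Optional[bool], fallback: bool) -> bool:
    # An explicit top-level config value wins; otherwise keep the tower default.
    return override if override is not None else fallback


assert resolve(None, True) is True    # no override set: tower default applies
assert resolve(False, True) is False  # explicit override disables the backend
```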
modeling_clip.py CHANGED
@@ -5,6 +5,8 @@
 # and adjusted for Jina CLIP
 
 import base64
+import importlib.util
+import warnings
 from functools import partial
 from io import BytesIO
 from typing import List, Optional, Tuple, Union
@@ -117,6 +119,61 @@ def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
     )
 
 
+def _resolve_attention_libs(config: JinaCLIPConfig):
+    use_text_flash_attn = (
+        config.use_text_flash_attn
+        if config.use_text_flash_attn is not None
+        else config.text_config.hf_model_config_kwargs.get('use_flash_attn', True)
+    )
+    use_vision_xformers = (
+        config.use_vision_xformers
+        if config.use_vision_xformers is not None
+        else config.vision_config.x_attention
+    )
+
+    def _resolve_use_text_flash_attn() -> bool:
+        if use_text_flash_attn:
+            if not torch.cuda.is_available():
+                warnings.warn('Flash attention requires CUDA, disabling')
+                return False
+            if importlib.util.find_spec('flash_attn') is None:
+                warnings.warn(
+                    'Flash attention is not installed. Check '
+                    'https://github.com/Dao-AILab/flash-attention?'
+                    'tab=readme-ov-file#installation-and-features '
+                    'for installation instructions, disabling'
+                )
+                return False
+            return True
+        return False
+
+    def _resolve_use_vision_xformers() -> bool:
+        if use_vision_xformers:
+            if not torch.cuda.is_available():
+                warnings.warn('xFormers requires CUDA, disabling')
+                return False
+            if importlib.util.find_spec('xformers') is None:
+                warnings.warn(
+                    'xFormers is not installed. Check '
+                    'https://github.com/facebookresearch/xformers?'
+                    'tab=readme-ov-file#installing-xformers for installation '
+                    'instructions, disabling'
+                )
+                return False
+            return True
+        return False
+
+    _use_text_flash_attn = _resolve_use_text_flash_attn()
+    _use_vision_xformers = _resolve_use_vision_xformers()
+
+    config.use_text_flash_attn = _use_text_flash_attn
+    config.use_vision_xformers = _use_vision_xformers
+    config.text_config.hf_model_config_kwargs['use_flash_attn'] = _use_text_flash_attn
+    config.vision_config.x_attention = _use_vision_xformers
+
+    return config
+
+
 class JinaCLIPPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for
@@ -218,16 +275,10 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
             )
 
+        config = _resolve_attention_libs(config)
         text_config = config.text_config
        vision_config = config.vision_config
 
-        if config.use_text_flash_attn is not None:
-            text_config.hf_model_config_kwargs['use_flash_attn'] = (
-                config.use_text_flash_attn
-            )
-        if config.use_vision_xformers is not None:
-            vision_config.x_attention = config.use_vision_xformers
-
         self.add_projections = config.add_projections
         self.projection_dim = config.projection_dim
         self.text_embed_dim = text_config.embed_dim
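A minimal sketch of the resulting behavior, loading on a CUDA-less machine so both backends get disabled with warnings; `jinaai/jina-clip-v1` is assumed here as a Hub checkpoint that ships this implementation via `trust_remote_code`:

```python
import warnings

from transformers import AutoModel

# With flash attention / xformers requested but unusable (e.g. no CUDA),
# loading should now warn and fall back to vanilla attention instead of raising.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    model = AutoModel.from_pretrained(
        'jinaai/jina-clip-v1', trust_remote_code=True
    )

for w in caught:
    print(w.message)  # e.g. 'Flash attention requires CUDA, disabling'
```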