gmastrapas committed
Commit 83560ca
Parent: 6ba3c14

fix: throw warnings if xformers or flash-attn can't be used

Files changed (1):
configuration_clip.py (+14 -0)
configuration_clip.py CHANGED
@@ -8,6 +8,7 @@ import os
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Union
 
+import torch
 from transformers import PretrainedConfig, logging
 
 logger = logging.get_logger(__name__)
@@ -157,6 +158,7 @@ class JinaCLIPConfig(PretrainedConfig):
         use_vision_xformers: Optional[bool] = None,
         matryoshka_dimensions: Optional[List[int]] = None,
         truncate_dim: Optional[int] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
         **kwargs,
     ):
         # If `_config_dict` exist, we use them for the backward compatibility.
@@ -284,6 +286,18 @@ class JinaCLIPConfig(PretrainedConfig):
                 'projections with `add_projections=True`.'
             )
 
+        if (
+            torch_dtype
+            and hasattr(torch, torch_dtype)
+            and type(getattr(torch, torch_dtype)) is torch.dtype
+        ):
+            self.torch_dtype = getattr(torch, torch_dtype)
+        else:
+            self.torch_dtype = torch_dtype
+
+        if not self.use_text_flash_attn or not torch.cuda.is_available():
+            self.torch_dtype = torch.float32
+
     @classmethod
     def from_text_vision_configs(
         cls,
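In effect, the new `__init__` logic accepts `torch_dtype` either as a string (e.g. 'float16') or as a `torch.dtype`, resolves strings via `getattr(torch, ...)`, and then forces `torch.float32` whenever text flash attention is disabled or no CUDA device is available. Below is a minimal standalone sketch of that behavior; `resolve_dtype` and `use_flash_attn` are illustrative names, not part of the committed file, and the sketch adds an `isinstance` guard because `hasattr` raises `TypeError` when given a non-string attribute name (the committed code assumes callers pass a string in that branch).

from typing import Optional, Union

import torch


def resolve_dtype(
    torch_dtype: Optional[Union[str, torch.dtype]],
    use_flash_attn: bool,
) -> Optional[Union[str, torch.dtype]]:
    # Hypothetical helper mirroring the dtype handling added in this commit.
    # A string naming a dtype attribute of the torch module (e.g. 'float16')
    # is mapped to the corresponding torch.dtype; anything else, including an
    # actual torch.dtype instance or None, passes through unchanged.
    if (
        torch_dtype
        and isinstance(torch_dtype, str)  # guard added in this sketch
        and hasattr(torch, torch_dtype)
        and type(getattr(torch, torch_dtype)) is torch.dtype
    ):
        torch_dtype = getattr(torch, torch_dtype)

    # Reduced precision is only kept when flash attention is enabled and a
    # CUDA device is present; otherwise fall back to full precision.
    if not use_flash_attn or not torch.cuda.is_available():
        torch_dtype = torch.float32

    return torch_dtype


print(resolve_dtype('float16', use_flash_attn=True))    # torch.float16 with CUDA, else torch.float32
print(resolve_dtype('bfloat16', use_flash_attn=False))  # torch.float32

The apparent intent is that a half-precision dtype requested at load time is quietly downgraded to `float32` on machines without CUDA or without flash attention, rather than failing later inside attention kernels that require it.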