gmastrapas
commited on
Commit
•
83560ca
1
Parent(s):
6ba3c14
fix: throw warnings if xformers or flash-attn cant be used
Browse files- configuration_clip.py +14 -0
configuration_clip.py
CHANGED
@@ -8,6 +8,7 @@ import os
|
|
8 |
from copy import deepcopy
|
9 |
from typing import Any, Dict, List, Optional, Union
|
10 |
|
|
|
11 |
from transformers import PretrainedConfig, logging
|
12 |
|
13 |
logger = logging.get_logger(__name__)
|
@@ -157,6 +158,7 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
157 |
use_vision_xformers: Optional[bool] = None,
|
158 |
matryoshka_dimensions: Optional[List[int]] = None,
|
159 |
truncate_dim: Optional[int] = None,
|
|
|
160 |
**kwargs,
|
161 |
):
|
162 |
# If `_config_dict` exist, we use them for the backward compatibility.
|
@@ -284,6 +286,18 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
284 |
'projections with `add_projections=True`.'
|
285 |
)
|
286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
@classmethod
|
288 |
def from_text_vision_configs(
|
289 |
cls,
|
|
|
8 |
from copy import deepcopy
|
9 |
from typing import Any, Dict, List, Optional, Union
|
10 |
|
11 |
+
import torch
|
12 |
from transformers import PretrainedConfig, logging
|
13 |
|
14 |
logger = logging.get_logger(__name__)
|
|
|
158 |
use_vision_xformers: Optional[bool] = None,
|
159 |
matryoshka_dimensions: Optional[List[int]] = None,
|
160 |
truncate_dim: Optional[int] = None,
|
161 |
+
torch_dtype: Optional[Union[str, torch.dtype]] = None,
|
162 |
**kwargs,
|
163 |
):
|
164 |
# If `_config_dict` exist, we use them for the backward compatibility.
|
|
|
286 |
'projections with `add_projections=True`.'
|
287 |
)
|
288 |
|
289 |
+
if (
|
290 |
+
torch_dtype
|
291 |
+
and hasattr(torch, torch_dtype)
|
292 |
+
and type(getattr(torch, torch_dtype)) is torch.dtype
|
293 |
+
):
|
294 |
+
self.torch_dtype = getattr(torch, torch_dtype)
|
295 |
+
else:
|
296 |
+
self.torch_dtype = torch_dtype
|
297 |
+
|
298 |
+
if not self.use_text_flash_attn or not torch.cuda.is_available():
|
299 |
+
self.torch_dtype = torch.float32
|
300 |
+
|
301 |
@classmethod
|
302 |
def from_text_vision_configs(
|
303 |
cls,
|