gmastrapas committed
Commit cd77b48
1 Parent(s): 96e41b8

feat: jina clip v2 implementation

Files changed (7)
  1. .gitignore +70 -0
  2. configuration_clip.py +0 -6
  3. eva_model.py +27 -27
  4. hf_model.py +56 -85
  5. modeling_clip.py +197 -156
  6. processing_clip.py +0 -1
  7. transform.py +95 -179
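The towers added in this commit are meant to be loaded through `transformers` with remote code enabled and queried via the `encode_text` / `encode_image` helpers defined in `modeling_clip.py` below. A minimal usage sketch follows; the repository id, the image path, and the `truncate_dim` value are illustrative assumptions, not part of the diff.

```python
# Usage sketch only -- repository id, image path and truncate_dim are assumptions.
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v2', trust_remote_code=True)

text_embs = model.encode_text(
    ['A photo of a beach'],   # str or List[str]
    truncate_dim=512,         # optional Matryoshka-style truncation (see modeling_clip.py)
)
image_embs = model.encode_image(
    ['photo.jpg'],            # paths, URLs, PIL images or data:image/ strings
    convert_to_tensor=True,
)
```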
.gitignore ADDED
@@ -0,0 +1,70 @@
+# Project specific
+__init__.py
+pyproject.toml
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# PyCharm
+.idea/
configuration_clip.py CHANGED
@@ -47,11 +47,9 @@ class JinaCLIPTextConfig(PretrainedConfig):
        configdict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
-
        # get the text config dict if we are loading from JinaCLIPConfig
        if configdict.get('model_type') == 'jina_clip':
            configdict = configdict['text_config']
-
        if (
            'model_type' in configdict
            and hasattr(cls, 'model_type')
@@ -62,7 +60,6 @@ class JinaCLIPTextConfig(PretrainedConfig):
                f'instantiate a model of type {cls.model_type}. This is not supported '
                'for all configurations of models and can yield errors.'
            )
-
        return cls.from_dict(configdict, **kwargs)


@@ -125,11 +122,9 @@ class JinaCLIPVisionConfig(PretrainedConfig):
        configdict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
-
        # get the vision config dict if we are loading from JinaCLIPConfig
        if configdict.get('model_type') == 'jina_clip':
            configdict = configdict['vision_config']
-
        if (
            'model_type' in configdict
            and hasattr(cls, 'model_type')
@@ -140,7 +135,6 @@ class JinaCLIPVisionConfig(PretrainedConfig):
                f'instantiate a model of type {cls.model_type}. This is not supported '
                'for all configurations of models and can yield errors.'
            )
-
        return cls.from_dict(configdict, **kwargs)
eva_model.py CHANGED
@@ -9,12 +9,12 @@ from functools import partial

import torch
import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as f

try:
-    from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+    from timm.models.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
except ImportError or ModuleNotFoundError:
-    from timm.layers import drop_path, to_2tuple, trunc_normal_
+    from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_

from .rope_embeddings import VisionRotaryEmbeddingFast

@@ -81,7 +81,7 @@ class DropPath(nn.Module):
        self.drop_prob = drop_prob

    def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training)
+        return timm_drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
@@ -244,17 +244,17 @@ class Attention(nn.Module):
        self.rope = rope

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
-        B, N, C = x.shape
+        b, n, _ = x.shape
        if self.subln:
-            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
-            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
-            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
+            q = f.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
+            k = f.linear(input=x, weight=self.k_proj.weight, bias=None)
+            v = f.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

-            q = q.reshape(B, N, self.num_heads, -1).permute(
+            q = q.reshape(b, n, self.num_heads, -1).permute(
                0, 2, 1, 3
            )  # B, num_heads, N, C
-            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
-            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+            k = k.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)
+            v = v.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)
        else:
            qkv_bias = None
            if self.q_bias is not None:
@@ -266,8 +266,8 @@ class Attention(nn.Module):
                    )
                )

-            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
-            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
+            qkv = f.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            qkv = qkv.reshape(b, n, 3, self.num_heads, -1).permute(
                2, 0, 3, 1, 4
            )  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]
@@ -298,7 +298,7 @@ class Attention(nn.Module):
                p=self.xattn_drop,
                scale=self.scale,
            )
-            x = x.reshape(B, N, -1)
+            x = x.reshape(b, n, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
@@ -329,7 +329,7 @@ class Attention(nn.Module):
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

-            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+            x = (attn @ v).transpose(1, 2).reshape(b, n, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
@@ -461,12 +461,12 @@ class PatchEmbed(nn.Module):
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

-    def forward(self, x, **kwargs):
+    def forward(self, x, **_):
        target_dtype = self.proj.weight.dtype
-        B, C, H, W = x.shape
+        _, __, h, w = x.shape
        # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], (
-            f"Input image size ({H}*{W}) doesn't match model "
+        assert h == self.img_size[0] and w == self.img_size[1], (
+            f"Input image size ({h}*{w}) doesn't match model "
            f'({self.img_size[0]}*{self.img_size[1]}).'
        )
        x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
@@ -559,9 +559,8 @@ class EVAVisionTransformer(nn.Module):
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
-        self.num_features = (
-            self.embed_dim
-        ) = embed_dim  # num_features for consistency with other models
+        # num_features for consistency with other models
+        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size,
@@ -666,8 +665,8 @@ class EVAVisionTransformer(nn.Module):
        self.grad_checkpointing = grad_checkpointing

    def fix_init_weight(self):
-        def rescale(param, layer_id):
-            param.div_(math.sqrt(2.0 * layer_id))
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
@@ -679,7 +678,8 @@ class EVAVisionTransformer(nn.Module):
    def get_cast_dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

-    def _init_weights(self, m):
+    @staticmethod
+    def _init_weights(m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
@@ -691,7 +691,7 @@ class EVAVisionTransformer(nn.Module):
    def get_num_layers(self):
        return len(self.blocks)

-    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+    def lock(self, unlocked_groups=0, *_, **__):
        assert (
            unlocked_groups == 0
        ), 'partial locking not currently supported for this model'
@@ -709,7 +709,7 @@ class EVAVisionTransformer(nn.Module):
    def get_classifier(self):
        return self.head

-    def reset_classifier(self, num_classes, global_pool=''):
+    def reset_classifier(self, num_classes, *_, **__):
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
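The import change above aliases timm's `drop_path` as `timm_drop_path` so the local `DropPath` wrapper no longer shadows the helper's name; note that `except ImportError or ModuleNotFoundError` evaluates to `except ImportError`, which still covers both cases since `ModuleNotFoundError` subclasses `ImportError`. A standalone sketch of the same fallback pattern, using the explicit tuple form:

```python
# Sketch of the version-tolerant timm import used above (tuple form shown here).
import torch.nn as nn

try:
    from timm.models.layers import drop_path as timm_drop_path  # older timm releases
except (ImportError, ModuleNotFoundError):
    from timm.layers import drop_path as timm_drop_path  # newer timm releases


class DropPath(nn.Module):
    """Stochastic-depth wrapper that delegates to timm's functional helper."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return timm_drop_path(x, self.drop_prob, self.training)
```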
hf_model.py CHANGED
@@ -1,6 +1,5 @@
import re
-from typing import Dict, Optional, Tuple
-
+from typing import Dict, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
@@ -10,9 +9,6 @@ from transformers.modeling_outputs import (
    BaseModelOutputWithPoolingAndCrossAttentions,
)

-"""
-HF architecture mapping
-"""

_HF_ARCH_DICT = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
@@ -41,22 +37,6 @@ _HF_ARCH_DICT = {
        },
        'pooler': 'mean_pooler',
    },
-    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
-    'mt5': {
-        'config_names': {
-            # unlimited seqlen
-            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
-            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
-            'context_length': '',
-            'vocab_size': 'vocab_size',
-            'width': 'd_model',
-            'heads': 'num_heads',
-            'layers': 'num_layers',
-            'layer_attr': 'block',
-            'token_embeddings_attr': 'embed_tokens',
-        },
-        'pooler': 'mean_pooler',
-    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    'bert': {
        'config_names': {
@@ -68,24 +48,8 @@ _HF_ARCH_DICT = {
        },
        'pooler': 'cls_pooler',
    },
-    # https://huggingface.co/docs/transformers/model_doc/m2m_100
-    'm2m_100': {
-        'config_names': {
-            'context_length': 'max_position_embeddings',
-            'vocab_size': 'vocab_size',
-            'width': 'd_model',
-            'heads': 'encoder_attention_heads',
-            'layers': 'encoder_layers',
-        },
-        'pooler': 'cls_pooler',
-    },
}

-
-"""
-Pooling functions
-"""
-
_POOLERS = {}


@@ -101,8 +65,6 @@ def register_pooler(cls):

@register_pooler
class MeanPooler(nn.Module):
-    """Mean pooling"""
-
    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
@@ -111,10 +73,6 @@ class MeanPooler(nn.Module):

@register_pooler
class MaxPooler(nn.Module):
-    """
-    Max pooling
-    """
-
    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        masked_output = x.last_hidden_state.masked_fill(
@@ -125,11 +83,7 @@ class MaxPooler(nn.Module):

@register_pooler
class ClsPooler(nn.Module):
-    """
-    CLS token pooling
-    """
-
-    def __init__(self, use_pooler_output=True):
+    def __init__(self, use_pooler_output: bool = True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output
@@ -147,15 +101,9 @@ class ClsPooler(nn.Module):
            and (x.pooler_output is not None)
        ):
            return x.pooler_output
-
        return x.last_hidden_state[:, self.cls_token_position, :]


-"""
-HF text model
-"""
-
-
class HFTextEncoder(nn.Module):
    output_tokens: torch.jit.Final[bool]

@@ -171,21 +119,21 @@ class HFTextEncoder(nn.Module):
        output_tokens: bool = False,
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
        model_config_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

-        # TODO: find better way to get this information
-        uses_transformer_pooler = pooler_type == 'cls_pooler'
        model_config_kwargs = model_config_kwargs or {}

        if config is None:
            self.config = AutoConfig.from_pretrained(
                model_name_or_path,
                trust_remote_code=trust_remote_code,
-                code_revision=revision,
+                revision=revision,
+                code_revision=code_revision,
            )
            self.config.update(model_config_kwargs)
            create_func, model_args = (
@@ -193,34 +141,40 @@ class HFTextEncoder(nn.Module):
                if pretrained
                else (AutoModel.from_config, self.config)
            )
-            # TODO: do all model configs have this attribute?
-            # PretrainedConfig does so yes??
            if (
                hasattr(self.config, 'is_encoder_decoder')
                and self.config.is_encoder_decoder
            ):
-                self.transformer = create_func(model_args)
+                self.transformer = create_func(
+                    model_args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    code_revision=code_revision,
+                    **model_config_kwargs,
+                )
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(
                    model_args,
                    trust_remote_code=trust_remote_code,
-                    add_pooling_layer=uses_transformer_pooler,
-                    code_revision=revision,
+                    revision=revision,
+                    add_pooling_layer=False,
+                    code_revision=code_revision,
+                    **model_config_kwargs,
                )
        else:
            self.config = config
            self.config.update(model_config_kwargs)
-            self.transformer = AutoModel.from_config(self.config)
-
-        if pooler_type is None:  # get default arch pooler
-            pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']
-
-        # FIXME downstream users of OpenCLIP models use these attr,
-        # need to verify valid across all models
+            self.transformer = AutoModel.from_config(
+                self.config,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                code_revision=code_revision,
+            )
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

+        pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(
@@ -228,7 +182,7 @@ class HFTextEncoder(nn.Module):
        )
        if (d_model == output_dim) and (proj_type is None):  # do we always need a proj?
            self.proj = nn.Identity()
-        elif proj_type == 'linear':
+        elif (d_model != output_dim) or proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
@@ -238,27 +192,52 @@ class HFTextEncoder(nn.Module):
                nn.Linear(hidden_size, output_dim, bias=proj_bias),
            )

-    def forward(self, x: torch.Tensor):
+        self._task_instructions = {}
+        self._lora_adaptation_map = {}
+        self._supports_task_instructions = False
+        self._supports_lora = False
+        if (
+            hasattr(self.transformer, '_adaptation_map')
+            and len(self.transformer._adaptation_map) > 0
+        ):
+            self._lora_adaptation_map = self.transformer._adaptation_map
+            self._supports_lora = True
+        if (
+            hasattr(self.transformer, '_task_instructions')
+            and len(self.transformer._task_instructions) > 0
+        ):
+            self._task_instructions = self.transformer._task_instructions
+            self._supports_task_instructions = True
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, _=True):
+        self.transformer.gradient_checkpointing_enable()
+
+    def init_parameters(self):
+        pass
+
+    def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
        attn_mask = (x != self.config.pad_token_id).long()
-        out = self.transformer(input_ids=x, attention_mask=attn_mask)
+        kwargs = {}
+        if adapter_mask is not None:
+            kwargs['adapter_mask'] = adapter_mask
+        out = self.transformer(input_ids=x, attention_mask=attn_mask, **kwargs)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)
-
-        seq_len = out.last_hidden_state.shape[1]
+        seqlen = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[
-                :, torch.arange(seq_len) != self.pooler.cls_token_position, :
+                :, torch.arange(seqlen) != self.pooler.cls_token_position, :
            ]
            if isinstance(self.pooler, ClsPooler)
            else out.last_hidden_state
        )
-
        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
-        if not unlocked_layers:  # full freezing
+        if not unlocked_layers:
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
@@ -287,11 +266,3 @@ class HFTextEncoder(nn.Module):
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )
-
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, _=True):
-        self.transformer.gradient_checkpointing_enable()
-
-    def init_parameters(self):
-        pass
-
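The constructor now advertises LoRA and instruction support only when the wrapped transformer exposes `_adaptation_map` / `_task_instructions`, and `forward` accepts an optional `adapter_mask`. A hedged sketch of how a caller can turn a task name into that mask; the task name shown is an assumption, since valid names come from the underlying text model:

```python
import torch


def build_adapter_mask(encoder, task, n_examples):
    # Mirrors the checks that encode_text() in modeling_clip.py performs.
    if not encoder._supports_lora:
        return None  # encoder runs with its general-purpose weights
    if task not in encoder._lora_adaptation_map:
        raise ValueError(f'Unsupported task {task!r}')
    task_id = encoder._lora_adaptation_map[task]
    return torch.full((n_examples,), task_id, dtype=torch.int32)


# e.g. mask = build_adapter_mask(text_encoder, 'retrieval.query', input_ids.shape[0])
#      projected = text_encoder(input_ids, adapter_mask=mask)
```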
modeling_clip.py CHANGED
@@ -14,6 +14,7 @@ import requests
import torch
import torch.nn.functional as f
import torch.utils.checkpoint
+from PIL import Image
from torch import nn
from transformers import (
    AutoImageProcessor,
@@ -35,13 +36,12 @@ try:

    has_tqdm = True
except ImportError:
+    trange = None
    has_tqdm = False

from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
from .eva_model import EVAVisionTransformer
from .hf_model import HFTextEncoder
-
-# needed for HF to correctly import in cache
from .rope_embeddings import VisionRotaryEmbeddingFast  # noqa: F401
from .transform import (  # noqa: F401
    OPENAI_DATASET_MEAN,
@@ -157,6 +157,9 @@ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
+        use_lora: bool = False,
+        adapter_mask: Optional[torch.Tensor] = None,
+        task: Optional[str] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
@@ -164,7 +167,12 @@ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
-        feats = self.text_model(x=x)
+        feats = self.text_model(
+            x=x,
+            use_lora=use_lora,
+            adapter_mask=adapter_mask,
+            task=task,
+        )
        out = CLIPTextModelOutput(text_embeds=feats)
        return out if return_dict else out.to_tuple()
@@ -220,7 +228,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
        vision_config = config.vision_config

        if config.use_text_flash_attn is not None:
-            text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
+            text_config.hf_model_config_kwargs['use_flash_attn'] = (
+                config.use_text_flash_attn
+            )
        if config.use_vision_xformers is not None:
            vision_config.x_attention = config.use_vision_xformers

@@ -228,13 +238,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.embed_dim
        self.vision_embed_dim = vision_config.embed_dim
-
        self.text_model = _build_text_tower(text_config)
        self.vision_model = _build_vision_tower(vision_config)
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )
-
        if self.add_projections:
            self.visual_projection = nn.Linear(
                self.vision_embed_dim, self.projection_dim, bias=False
@@ -267,11 +275,12 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
    def get_text_features(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
+        adapter_mask: Optional[torch.Tensor] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
-        return self.text_projection(self.text_model(x=x))
+        return self.text_projection(self.text_model(x=x, adapter_mask=adapter_mask))

    def get_image_features(
        self,
@@ -286,24 +295,24 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
        )
        return self.visual_projection(self.vision_model(x=x))

-    def truncate_embeddings(self, embeddings, truncate_dim):
+    def _truncate_embeddings(self, embeddings: torch.Tensor, truncate_dim: int):
        if not self.config.matryoshka_dimensions:
            logger.warning(
-                "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
-            )
-            return embeddings
-        elif truncate_dim in self.config.matryoshka_dimensions:
-            return embeddings[:, :truncate_dim]
-        else:
-            raise ValueError(
-                f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
-                f"Supported dimensions are {self.config.matryoshka_dimensions}."
+                'Model is not trained using Matryoshka Representation Learning, '
+                'truncating embeddings will not work optimally.'
            )
+        return embeddings[:, :truncate_dim]
+
+    @staticmethod
+    def _decode_image_data(image_data_str: str) -> Image:
+        header, data = image_data_str.split(',', 1)
+        image_data = base64.b64decode(data)
+        return Image.open(BytesIO(image_data))

    @torch.inference_mode()
-    def encode_text(
+    def encode_image(
        self,
-        sentences: Union[str, List[str]],
+        images: Union[str, List[Union[str, 'Image.Image']]],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
        convert_to_numpy: bool = True,
@@ -311,122 +320,129 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
        device: Optional[torch.device] = None,
        normalize_embeddings: bool = True,
        truncate_dim: Optional[int] = None,
-        **tokenizer_kwargs,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
-        Computes sentence embeddings
-        Args:
-            sentences(`str` or `List[str]`):
-                Sentence or sentences to be encoded
-            batch_size(`int`, *optional*, defaults to 32):
-                Batch size for the computation
-            show_progress_bar(`bool`, *optional*, defaults to None):
-                Show a progress bar when encoding sentences.
-                If set to None, progress bar is only shown when
-                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
-            convert_to_numpy(`bool`, *optional*, defaults to True):
-                If true, the output is a list of numpy vectors.
-                Else, it is a list of pytorch tensors.
-            convert_to_tensor(`bool`, *optional*, defaults to False):
-                If true, you get one large tensor as return.
-                Overwrites any setting from convert_to_numpy
-            device(`torch.device`, *optional*, defaults to None):
-                Which torch.device to use for the computation
-            normalize_embeddings(`bool`, *optional*, defaults to False):
-                If set to true, returned vectors will have length 1. In that case,
-                the faster dot-product (util.dot_score) instead of cosine similarity
-                can be used.
-            truncate_dim(`int`, *optional*, defaults to None):
-                The dimension to truncate sentence embeddings to. `None` does no truncation.
-            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
-                Keyword arguments for the tokenizer
-        Returns:
-            By default, a list of tensors is returned.
-            If convert_to_tensor, a stacked tensor is returned.
-            If convert_to_numpy, a numpy matrix is returned.
+        Computes image embeddings
+
+        Args:
+            images(`str` or `List[Union[str, Image.Image]]`):
+                Image paths, URLs, PIL images, or data:image/ strings to be encoded
+            batch_size(`int`, *optional*, defaults to 32):
+                Batch size for the computation
+            show_progress_bar(`bool`, *optional*, defaults to None):
+                Show a progress bar when encoding images. If set to None, progress bar
+                is only shown when `logger.level == logging.INFO` or
+                `logger.level == logging.DEBUG`
+            convert_to_numpy(`bool`, *optional*, defaults to True):
+                If true, the output is a list of numpy vectors. Else, it is a list of
+                pytorch tensors
+            convert_to_tensor(`bool`, *optional*, defaults to False):
+                If true, you get one large tensor as return. Overwrites any setting
+                from convert_to_numpy
+            device(`torch.device`, *optional*, defaults to None):
+                Which torch.device to use for the computation
+            normalize_embeddings(`bool`, *optional*, defaults to False):
+                If set to true, returned vectors will have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity
+                can be used
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. If set to `None`
+                no truncation is performed
+
+        Returns:
+            By default, a list of tensors is returned. If convert_to_tensor, a stacked
+            tensor is returned. If convert_to_numpy, a numpy matrix is returned
        """
-        is_training = self.training
+
+        _is_training = self.training
        self.eval()
-        all_embeddings = []

-        self.tokenizer = self.get_tokenizer()
+        self.preprocess = self.get_preprocess()
+        all_embeddings = []

        if show_progress_bar is None:
            show_progress_bar = (
                logger.getEffectiveLevel() == logging.INFO
                or logger.getEffectiveLevel() == logging.DEBUG
            )
-
        if convert_to_tensor:
            convert_to_numpy = False

-        input_was_string = False
-        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
-            sentences = [sentences]
-            input_was_string = True
+        _input_was_single_img = False
+        if isinstance(images, str) or not hasattr(images, '__len__'):
+            images = [images]
+            _input_was_single_img = True

        if device is not None:
            self.to(device)

-        permutation = np.argsort([-len(i) for i in sentences])
-        inverse_permutation = np.argsort(permutation)
-        sentences = [sentences[idx] for idx in permutation]
-
-        tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
-        tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
-        tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
+        _permutation = np.argsort([-len(str(i)) for i in images])
+        _inverse_permutation = np.argsort(_permutation)
+        images = [images[idx] for idx in _permutation]

        if has_tqdm:
            range_iter = trange(
                0,
-                len(sentences),
+                len(images),
                batch_size,
                desc='Encoding',
                disable=not show_progress_bar,
            )
        else:
-            range_iter = range(0, len(sentences), batch_size)
+            range_iter = range(0, len(images), batch_size)

        truncate_dim = truncate_dim or self.config.truncate_dim
+
        for i in range_iter:
-            encoded_input = self.tokenizer(
-                sentences[i : i + batch_size],
-                return_tensors='pt',
-                **tokenizer_kwargs,
-            ).to(self.device)
+            _processed_images = []
+            for img in images[i: i + batch_size]:
+                if isinstance(img, str):
+                    if img.startswith('http'):
+                        response = requests.get(img)
+                        image = Image.open(BytesIO(response.content)).convert('RGB')
+                    elif img.startswith('data:image/'):
+                        image = self._decode_image_data(img).convert('RGB')
+                    else:
+                        image = Image.open(img).convert('RGB')
+                elif isinstance(img, Image.Image):
+                    image = img.convert('RGB')
+                else:
+                    raise ValueError('Unsupported image format')
+                _processed_images.append(image)

-            embeddings = self.get_text_features(input_ids=encoded_input)
+            pixelvals = self.preprocess(_processed_images)
+            pixelvals = pixelvals.to(self.device)
+            embeddings = self.get_image_features(pixelvals)

            if truncate_dim:
                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
            if normalize_embeddings:
-                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                embeddings = f.normalize(embeddings, p=2, dim=1)
            if convert_to_numpy:
                embeddings = embeddings.cpu()
+
            all_embeddings.extend(embeddings)

-        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+        all_embeddings = [all_embeddings[idx] for idx in _inverse_permutation]

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
+            all_embeddings = np.asarray(
+                [emb.to(torch.float32).numpy() for emb in all_embeddings]
+            )

-        if input_was_string:
+        if _input_was_single_img:
            all_embeddings = all_embeddings[0]

-        self.train(is_training)
+        self.train(_is_training)
        return all_embeddings

-    def decode_data_image(data_image_str):
-        header, data = data_image_str.split(',', 1)
-        image_data = base64.b64decode(data)
-        return Image.open(BytesIO(image_data))
-
    @torch.inference_mode()
-    def encode_image(
+    def encode_text(
        self,
-        images: Union[str, List[Union[str, "Image.Image"]]],
+        sentences: Union[str, List[str]],
+        task: Optional[str] = None,
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
        convert_to_numpy: bool = True,
@@ -434,129 +450,153 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
        device: Optional[torch.device] = None,
        normalize_embeddings: bool = True,
        truncate_dim: Optional[int] = None,
+        **tokenizer_kwargs,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
-        Computes image embeddings.
-
+        Computes text embeddings
+
        Args:
-            images(`str` or `List[Union[str, Image.Image]]`):
-                image paths, URLs, PIL images, or data:image/ strings to be encoded
+            sentences(`str` or `List[str]`):
+                Sentence or sentences to be encoded
+            task(`str`, *optional*, defaults to `None`):
+                Specifies the task for which the encoding is intended. If `task` is
+                not provided, all LoRA adapters are disabled, and the model reverts
+                to its original, general-purpose weights
            batch_size(`int`, *optional*, defaults to 32):
                Batch size for the computation
            show_progress_bar(`bool`, *optional*, defaults to None):
-                Show a progress bar when encoding images.
-                If set to None, progress bar is only shown when
-                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
+                Show a progress bar when encoding sentences. If set to None, progress
+                bar is only shown when `logger.level == logging.INFO` or
+                `logger.level == logging.DEBUG`
            convert_to_numpy(`bool`, *optional*, defaults to True):
-                If true, the output is a list of numpy vectors.
-                Else, it is a list of pytorch tensors.
+                If true, the output is a list of numpy vectors. Else, it is a list of
+                pytorch tensors
            convert_to_tensor(`bool`, *optional*, defaults to False):
-                If true, you get one large tensor as return.
-                Overwrites any setting from convert_to_numpy
+                If true, you get one large tensor as return. Overwrites any setting
+                from convert_to_numpy
            device(`torch.device`, *optional*, defaults to None):
                Which torch.device to use for the computation
            normalize_embeddings(`bool`, *optional*, defaults to False):
                If set to true, returned vectors will have length 1. In that case,
                the faster dot-product (util.dot_score) instead of cosine similarity
-                can be used.
+                can be used
            truncate_dim(`int`, *optional*, defaults to None):
-                The dimension to truncate sentence embeddings to. `None` does no truncation.
+                The dimension to truncate sentence embeddings to. If set to `None`
+                no truncation is performed
+            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
+                Keyword arguments for the tokenizer
        Returns:
-            By default, a list of tensors is returned.
-            If convert_to_tensor, a stacked tensor is returned.
-            If convert_to_numpy, a numpy matrix is returned.
+            By default, a list of tensors is returned. If convert_to_tensor, a stacked
+            tensor is returned. If convert_to_numpy, a numpy matrix is returned.
        """
-
-        is_training = self.training
+        _is_training = self.training
        self.eval()
-
-        self.preprocess = self.get_preprocess()
+
        all_embeddings = []
-
+        self.tokenizer = self.get_tokenizer()
+
        if show_progress_bar is None:
            show_progress_bar = (
                logger.getEffectiveLevel() == logging.INFO
                or logger.getEffectiveLevel() == logging.DEBUG
            )
-
        if convert_to_tensor:
            convert_to_numpy = False
-
-        input_was_single_img = False
-        if isinstance(images, str) or not hasattr(images, '__len__'):
-            images = [images]
-            input_was_single_img = True
-
+
+        _input_was_string = False
+        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
+            sentences = [sentences]
+            _input_was_string = True
+
        if device is not None:
            self.to(device)
-
-        permutation = np.argsort([-len(str(i)) for i in images])
-        inverse_permutation = np.argsort(permutation)
-        images = [images[idx] for idx in permutation]
-
+
+        _permutation = np.argsort([-len(i) for i in sentences])
+        _inverse_permutation = np.argsort(_permutation)
+        sentences = [sentences[idx] for idx in _permutation]
+
+        tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
+        tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
+        tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
+
        if has_tqdm:
            range_iter = trange(
                0,
-                len(images),
+                len(sentences),
                batch_size,
                desc='Encoding',
                disable=not show_progress_bar,
            )
        else:
-            range_iter = range(0, len(images), batch_size)
-
-        from PIL import Image
+            range_iter = range(0, len(sentences), batch_size)

        truncate_dim = truncate_dim or self.config.truncate_dim
+
+        adapter_mask = None
+        if task:
+            if not self.text_model._supports_lora:
+                logger.warning('Text tower does not support LoRA task adaptation')
+            elif task not in self.text_model._lora_adaptation_map:
+                raise ValueError(
+                    f'Unsupported task \'{task}\'. Choose one of the following: '
+                    f'{", ".join(self.text_model._lora_adaptation_map)} or bypass the '
+                    '`task` argument to disable LoRA completely.'
+                )
+            else:
+                taskid = self.text_model._lora_adaptation_map[task]
+                nexamples = 1 if isinstance(sentences, str) else len(sentences)
+                adapter_mask = torch.full(
+                    (nexamples,), taskid, dtype=torch.int32, device=self.device
+                )
+            if not self.text_model._supports_task_instructions:
+                logger.warning('Text tower does not support task instructions')
+            elif task not in self.text_model._task_instructions:
+                raise ValueError(
+                    f'Unsupported task \'{task}\'. Choose one of the following: '
+                    f'{", ".join(self.text_model._task_instructions)} or bypass the '
+                    '`task` argument to disable task instructions completely.'
+                )
+            else:
+                instruction = self.text_model._task_instructions[task]
+                sentences = [instruction + sentence for sentence in sentences]
+
        for i in range_iter:
-            batch_images = images[i:i+batch_size]
-            processed_inputs = []
-
-            for img in batch_images:
-                if isinstance(img, str):
-                    if img.startswith('http'):
-                        response = requests.get(img)
-                        image = Image.open(BytesIO(response.content)).convert('RGB')
-                    elif img.startswith('data:image/'):
-                        image = decode_data_image(img).convert('RGB')
-                    else:
-                        image = Image.open(img).convert('RGB')
-                elif isinstance(img, Image.Image):
-                    image = img.convert('RGB')
-                else:
-                    raise ValueError("Unsupported image format")
-
-                processed_inputs.append(image)
-
-            processed_inputs = self.preprocess(processed_inputs)
-            processed_inputs = processed_inputs.to(self.device)
-            embeddings = self.get_image_features(processed_inputs)
-
+            tokens = self.tokenizer(
+                sentences[i: i + batch_size],
+                return_tensors='pt',
+                **tokenizer_kwargs,
+            ).to(self.device)
+
+            embeddings = self.get_text_features(
+                input_ids=tokens, adapter_mask=adapter_mask
+            )
            if truncate_dim:
                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
            if normalize_embeddings:
-                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                embeddings = f.normalize(embeddings, p=2, dim=1)
            if convert_to_numpy:
                embeddings = embeddings.cpu()
            all_embeddings.extend(embeddings)
-
-        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
-
+
+        all_embeddings = [all_embeddings[idx] for idx in _inverse_permutation]
+
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
-
-        if input_was_single_img:
+            all_embeddings = np.asarray(
+                [emb.to(torch.float32).numpy() for emb in all_embeddings]
+            )
+        if _input_was_string:
            all_embeddings = all_embeddings[0]
-
-        self.train(is_training)
+
+        self.train(_is_training)
        return all_embeddings

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
+        adapter_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
        *_,
@@ -566,8 +606,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        image_embeds = self.get_image_features(pixel_values=pixel_values)
-        text_embeds = self.get_text_features(input_ids=input_ids)
-
+        text_embeds = self.get_text_features(
+            input_ids=input_ids, adapter_mask=adapter_mask
+        )
        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
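One detail worth highlighting from the encode paths above: `truncate_dim` slices embeddings before the optional L2 normalization, which is what keeps truncated (Matryoshka-style) embeddings directly usable for cosine similarity. A small sketch of that ordering, with illustrative dimensions:

```python
import torch
import torch.nn.functional as f

emb = torch.randn(4, 1024)          # batch of full-size embeddings (illustrative)
truncate_dim = 256                  # assumed to be one of the trained Matryoshka dims

emb = emb[:, :truncate_dim]         # what the model's truncation helper does
emb = f.normalize(emb, p=2, dim=1)  # normalization happens after truncation

scores = emb @ emb.t()              # rows are unit-norm, so this is cosine similarity
```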
processing_clip.py CHANGED
@@ -72,7 +72,6 @@ class JinaCLIPImageProcessor(BaseImageProcessor):
        return output

    def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
-
        _transform_needs_rebuild = False
        for k, v in kwargs.items():
            if k in self._valid_processor_keys:
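For context, `preprocess` above rebuilds its torchvision transform only when a recognized processor key is overridden at call time. A hedged sketch of such a call; the processor defaults and the `size` override are assumptions:

```python
# Sketch only -- assumes JinaCLIPImageProcessor can be built with its defaults and
# that 'size' is one of its valid processor keys.
from PIL import Image

processor = JinaCLIPImageProcessor()
image = Image.new('RGB', (640, 480))

features = processor.preprocess(image, size=384)  # recognized kwarg -> transform rebuilt
```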
transform.py CHANGED
@@ -1,11 +1,10 @@
1
- import numbers
2
  import random
3
  import warnings
4
  from dataclasses import asdict, dataclass
5
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
 
7
  import torch
8
- import torchvision.transforms.functional as F
9
  from torchvision.transforms import (
10
  CenterCrop,
11
  ColorJitter,
@@ -23,88 +22,93 @@ OPENAI_DATASET_MEAN = tuple(OPENAI_CLIP_MEAN)
23
  OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
24
 
25
 
26
- @dataclass
27
- class PreprocessCfg:
28
- size: Union[int, Tuple[int, int]] = 224
29
- mode: str = 'RGB'
30
- mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
31
- std: Tuple[float, ...] = OPENAI_DATASET_STD
32
- interpolation: str = 'bicubic'
33
- resize_mode: str = 'shortest'
34
- fill_color: int = 0
35
-
36
- def __post_init__(self):
37
- assert self.mode in ('RGB',)
38
-
39
- @property
40
- def num_channels(self):
41
- return 3
42
-
43
- @property
44
- def input_size(self):
45
- return (self.num_channels,) + (self.size, self.size)
46
-
47
-
48
- _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
49
 
50
 
51
- def merge_preprocess_dict(
52
- base: Union[PreprocessCfg, Dict],
53
- overlay: Dict,
54
- ):
55
- """Merge overlay key-value pairs on top of base preprocess cfg or dict.
56
- Input dicts are filtered based on PreprocessCfg fields.
57
  """
58
- if isinstance(base, PreprocessCfg):
59
- base_clean = asdict(base)
60
- else:
61
- base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
62
- if overlay:
63
- overlay_clean = {
64
- k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None
65
- }
66
- base_clean.update(overlay_clean)
67
- return base_clean
68
 
 
 
69
 
70
- def merge_preprocess_kwargs(base: Union[PreprocessCfg, Dict], **kwargs):
71
- return merge_preprocess_dict(base, kwargs)
 
 
 
 
 
 
 
 
 
72
 
 
 
 
73
 
74
- @dataclass
75
- class AugmentationCfg:
76
- scale: Tuple[float, float] = (0.9, 1.0)
77
- ratio: Optional[Tuple[float, float]] = None
78
- color_jitter: Optional[
79
- Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
80
- ] = None
81
- re_prob: Optional[float] = None
82
- re_count: Optional[int] = None
83
- use_timm: bool = False
84
 
85
- # params for simclr_jitter_gray
86
- color_jitter_prob: float = None
87
- gray_scale_prob: float = None
 
 
 
88
 
 
 
 
 
 
 
89
 
90
- def _setup_size(size, error_msg):
91
- if isinstance(size, numbers.Number):
92
- return int(size), int(size)
 
 
 
93
 
94
- if isinstance(size, Sequence) and len(size) == 1:
95
- return size[0], size[0]
 
 
96
 
97
- if len(size) != 2:
98
- raise ValueError(error_msg)
 
 
99
 
100
- return size
 
101
 
102
 
103
- class ResizeKeepRatio:
104
- """Resize and Keep Ratio
105
 
106
- Copy & paste from `timm`
107
- """
 
108
 
109
  def __init__(
110
  self,
@@ -159,8 +163,9 @@ class ResizeKeepRatio:
159
  ratio_factor[0] / aspect_factor,
160
  ratio_factor[1] * aspect_factor,
161
  )
162
- size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
163
- return size
 
164
 
165
  def __call__(self, img):
166
  """
@@ -180,7 +185,7 @@ class ResizeKeepRatio:
180
  self.random_aspect_prob,
181
  self.random_aspect_range,
182
  )
183
- img = F.resize(img, size, self.interpolation)
184
  return img
185
 
186
  def __repr__(self):
@@ -190,92 +195,8 @@ class ResizeKeepRatio:
          return format_string


- def center_crop_or_pad(
-     img: torch.Tensor, output_size: List[int], fill=0
- ) -> torch.Tensor:
-     """Center crops and/or pads the given image.
-     If the image is torch Tensor, it is expected
-     to have [..., H, W] shape, where ... means an arbitrary number of leading
-     dimensions. If image size is smaller than output size along any edge, image is
-     padded with 0 and then center cropped.
-
-     Args:
-         img (PIL Image or Tensor): Image to be cropped.
-         output_size (sequence or int): (height, width) of the crop box. If int or
-             sequence with single int, it is used for both directions.
-         fill (int, Tuple[int]): Padding color
-
-     Returns:
-         PIL Image or Tensor: Cropped image.
-     """
-     if isinstance(output_size, numbers.Number):
-         output_size = (int(output_size), int(output_size))
-     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
-         output_size = (output_size[0], output_size[0])
-
-     _, image_height, image_width = F.get_dimensions(img)
-     crop_height, crop_width = output_size
-
-     if crop_width > image_width or crop_height > image_height:
-         padding_ltrb = [
-             (crop_width - image_width) // 2 if crop_width > image_width else 0,
-             (crop_height - image_height) // 2 if crop_height > image_height else 0,
-             (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
-             (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
-         ]
-         img = F.pad(img, padding_ltrb, fill=fill)
-     _, image_height, image_width = F.get_dimensions(img)
-     if crop_width == image_width and crop_height == image_height:
-         return img
-
-     crop_top = int(round((image_height - crop_height) / 2.0))
-     crop_left = int(round((image_width - crop_width) / 2.0))
-     return F.crop(img, crop_top, crop_left, crop_height, crop_width)
-
-
- class CenterCropOrPad(torch.nn.Module):
-     """Crops the given image at the center.
-     If the image is torch Tensor, it is expected
-     to have [..., H, W] shape, where ... means an arbitrary number of leading
-     dimensions. If image size is smaller than output size along any edge, image is
-     padded with 0 and then center cropped.
-
-     Args:
-         size (sequence or int): Desired output size of the crop. If size is an
-             int instead of sequence like (h, w), a square crop (size, size) is
-             made. If provided a sequence of length 1, it will be interpreted as
-             (size[0], size[0]).
-     """
-
-     def __init__(self, size, fill=0):
-         super().__init__()
-         self.size = _setup_size(
-             size, error_msg='Please provide only two dimensions (h, w) for size.'
-         )
-         self.fill = fill
-
-     def forward(self, img):
-         """
-         Args:
-             img (PIL Image or Tensor): Image to be cropped.
-
-         Returns:
-             PIL Image or Tensor: Cropped image.
-         """
-         return center_crop_or_pad(img, self.size, fill=self.fill)
-
-     def __repr__(self) -> str:
-         return f'{self.__class__.__name__}(size={self.size})'
-
-
- def _convert_to_rgb(image):
-     return image.convert('RGB')
-
-
  class _ColorJitter(object):
-     """
-     Apply Color Jitter to the PIL image with a specified probability.
-     """

      def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
          assert 0.0 <= p <= 1.0
@@ -292,9 +213,7 @@ class _ColorJitter(object):


  class _GrayScale(object):
-     """
-     Apply Gray Scale to the PIL image with a specified probability.
-     """

      def __init__(self, p=0.2):
          assert 0.0 <= p <= 1.0
@@ -308,6 +227,20 @@ class _GrayScale(object):
          return img


  def image_transform(
      image_size: Union[int, Tuple[int, int]],
      is_train: bool,
@@ -407,10 +340,10 @@ def image_transform(
      else:
          if resize_mode == 'longest':
              transforms = [
-                 ResizeKeepRatio(
                      image_size, interpolation=interpolation_mode, longest=1
                  ),
-                 CenterCropOrPad(image_size, fill=fill_color),
              ]
          elif resize_mode == 'squash':
              if isinstance(image_size, int):
@@ -428,7 +361,7 @@ def image_transform(
                  transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
              else:
                  # resize shortest edge to matching target dim for non-square target
-                 transforms = [ResizeKeepRatio(image_size)]
              transforms += [CenterCrop(image_size)]

          transforms.extend(
@@ -439,20 +372,3 @@ def image_transform(
              ]
          )
          return Compose(transforms)
-
-
- def image_transform_v2(
-     cfg: PreprocessCfg,
-     is_train: bool,
-     aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
- ):
-     return image_transform(
-         image_size=cfg.size,
-         is_train=is_train,
-         mean=cfg.mean,
-         std=cfg.std,
-         interpolation=cfg.interpolation,
-         resize_mode=cfg.resize_mode,
-         fill_color=cfg.fill_color,
-         aug_cfg=aug_cfg,
-     )
 
 
  import random
  import warnings
  from dataclasses import asdict, dataclass
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

  import torch
+ import torchvision.transforms.functional as f
  from torchvision.transforms import (
      CenterCrop,
      ColorJitter,
 
  OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)


+ def _setup_size(size, error_msg):
+     if isinstance(size, int):
+         return size, size
+     if isinstance(size, Sequence) and len(size) == 1:
+         return size[0], size[0]
+     if len(size) != 2:
+         raise ValueError(error_msg)
+     return size


+ def _center_crop_or_pad(
+     img: torch.Tensor,
+     output_size: Union[int, Tuple[int, ...], List[int]],
+     fill: Union[int, Tuple[int]] = 0,
+ ) -> torch.Tensor:
      """
+     Center crops and/or pads the given image. If the image is torch Tensor, it is
+     expected to have [..., H, W] shape, where ... means an arbitrary number of leading
+     dimensions. If image size is smaller than output size along any edge, image is
+     padded with 0 and then center cropped.
+     """
+     if isinstance(output_size, int):
+         output_size = (output_size, output_size)
+     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+         output_size = (output_size[0], output_size[0])

+     _, image_height, image_width = f.get_dimensions(img)
+     crop_height, crop_width = output_size

+     if crop_width > image_width or crop_height > image_height:
+         padding_ltrb = [
+             (crop_width - image_width) // 2 if crop_width > image_width else 0,
+             (crop_height - image_height) // 2 if crop_height > image_height else 0,
+             (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+             (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+         ]
+         img = f.pad(img, padding_ltrb, fill=fill)
+     _, image_height, image_width = f.get_dimensions(img)
+     if crop_width == image_width and crop_height == image_height:
+         return img

+     crop_top = int(round((image_height - crop_height) / 2.0))
+     crop_left = int(round((image_width - crop_width) / 2.0))
+     return f.crop(img, crop_top, crop_left, crop_height, crop_width)


+ class _CenterCropOrPad(torch.nn.Module):
+     """Crops the given image at the center.
+     If the image is torch Tensor, it is expected
+     to have [..., H, W] shape, where ... means an arbitrary number of leading
+     dimensions. If image size is smaller than output size along any edge, image is
+     padded with 0 and then center cropped.

+     Args:
+         size (sequence or int): Desired output size of the crop. If size is an
+             int instead of sequence like (h, w), a square crop (size, size) is
+             made. If provided a sequence of length 1, it will be interpreted as
+             (size[0], size[0]).
+     """

+     def __init__(self, size, fill=0):
+         super().__init__()
+         self.size = _setup_size(
+             size, error_msg='Please provide only two dimensions (h, w) for size.'
+         )
+         self.fill = fill

+     def forward(self, img):
+         """
+         Args:
+             img (PIL Image or Tensor): Image to be cropped.

+         Returns:
+             PIL Image or Tensor: Cropped image.
+         """
+         return _center_crop_or_pad(img, self.size, fill=self.fill)

+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}(size={self.size})'


+ def _convert_to_rgb(image):
+     return image.convert('RGB')

+
+ class _ResizeKeepRatio:
+     """Resize while keeping ratio. Copied from timm"""

      def __init__(
          self,
 
                  ratio_factor[0] / aspect_factor,
                  ratio_factor[1] * aspect_factor,
              )
+         return [
+             round(x * factor / ratio) for x, factor in zip(source_size, ratio_factor)
+         ]

      def __call__(self, img):
          """

              self.random_aspect_prob,
              self.random_aspect_range,
          )
+         img = f.resize(img, size, self.interpolation)
          return img

      def __repr__(self):
 
          return format_string


  class _ColorJitter(object):
+     """Apply color jitter to the PIL image with a specified probability"""

      def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
          assert 0.0 <= p <= 1.0


  class _GrayScale(object):
+     """Apply gray scale to the PIL image with a specified probability"""

      def __init__(self, p=0.2):
          assert 0.0 <= p <= 1.0

          return img


+ @dataclass
+ class AugmentationCfg:
+     scale: Tuple[float, float] = (0.9, 1.0)
+     ratio: Optional[Tuple[float, float]] = None
+     color_jitter: Optional[
+         Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
+     ] = None
+     re_prob: Optional[float] = None
+     re_count: Optional[int] = None
+     use_timm: bool = False
+     color_jitter_prob: float = None
+     gray_scale_prob: float = None
+
+
  def image_transform(
      image_size: Union[int, Tuple[int, int]],
      is_train: bool,

      else:
          if resize_mode == 'longest':
              transforms = [
+                 _ResizeKeepRatio(
                      image_size, interpolation=interpolation_mode, longest=1
                  ),
+                 _CenterCropOrPad(image_size, fill=fill_color),
              ]
          elif resize_mode == 'squash':
              if isinstance(image_size, int):

                  transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
              else:
                  # resize shortest edge to matching target dim for non-square target
+                 transforms = [_ResizeKeepRatio(image_size)]
              transforms += [CenterCrop(image_size)]

          transforms.extend(

              ]
          )
          return Compose(transforms)
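
For reference, a minimal usage sketch of the refactored eval-time pipeline. The keyword names below are inferred from the removed image_transform_v2 wrapper, and the import path assumes transform.py is importable as a top-level module; treat both as assumptions rather than a documented API.

from PIL import Image

from transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform

# Eval-time pipeline with resize_mode='longest': _ResizeKeepRatio (longest side)
# -> _CenterCropOrPad -> _convert_to_rgb -> ToTensor -> Normalize.
preprocess = image_transform(
    image_size=224,
    is_train=False,
    mean=OPENAI_DATASET_MEAN,
    std=OPENAI_DATASET_STD,
    resize_mode='longest',
    fill_color=0,
)
pixel_values = preprocess(Image.open('photo.jpg')).unsqueeze(0)  # (1, 3, 224, 224)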