gmastrapas commited on
Commit
cdebfc7
1 Parent(s): ae96581

fix: various fixes

Browse files
Files changed (5) hide show
  1. configuration_clip.py +4 -0
  2. eva_model.py +2 -1
  3. hf_model.py +67 -24
  4. modeling_clip.py +20 -27
  5. rope_embeddings.py +4 -9
configuration_clip.py CHANGED
@@ -24,6 +24,8 @@ class JinaCLIPTextConfig(PretrainedConfig):
24
  embed_dim: int = 768,
25
  hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
26
  hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
 
 
27
  pooler_type: Optional[str] = None,
28
  proj_type: Optional[str] = None,
29
  proj_bias: bool = False,
@@ -34,6 +36,8 @@ class JinaCLIPTextConfig(PretrainedConfig):
34
  self.embed_dim = embed_dim
35
  self.hf_model_name_or_path = hf_model_name_or_path
36
  self.hf_model_config_kwargs = hf_model_config_kwargs or {}
 
 
37
  self.pooler_type = pooler_type
38
  self.proj_type = proj_type
39
  self.proj_bias = proj_bias
 
24
  embed_dim: int = 768,
25
  hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
26
  hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
27
+ default_instruction_task: Optional[str] = None,
28
+ default_lora_task: Optional[str] = None,
29
  pooler_type: Optional[str] = None,
30
  proj_type: Optional[str] = None,
31
  proj_bias: bool = False,
 
36
  self.embed_dim = embed_dim
37
  self.hf_model_name_or_path = hf_model_name_or_path
38
  self.hf_model_config_kwargs = hf_model_config_kwargs or {}
39
+ self.default_instruction_task = default_instruction_task
40
+ self.default_lora_task = default_lora_task
41
  self.pooler_type = pooler_type
42
  self.proj_type = proj_type
43
  self.proj_bias = proj_bias
eva_model.py CHANGED
@@ -12,7 +12,8 @@ import torch.nn as nn
12
  import torch.nn.functional as f
13
 
14
  try:
15
- from timm.models.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
 
16
  except ImportError or ModuleNotFoundError:
17
  from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
18
 
 
12
  import torch.nn.functional as f
13
 
14
  try:
15
+ from timm.models.layers import drop_path as timm_drop_path
16
+ from timm.models.layers import to_2tuple, trunc_normal_
17
  except ImportError or ModuleNotFoundError:
18
  from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
19
 
hf_model.py CHANGED
@@ -1,5 +1,7 @@
1
  import re
 
2
  from typing import Dict, Optional
 
3
  import torch
4
  import torch.nn as nn
5
  from transformers import AutoConfig, AutoModel, PretrainedConfig
@@ -9,7 +11,6 @@ from transformers.modeling_outputs import (
9
  BaseModelOutputWithPoolingAndCrossAttentions,
10
  )
11
 
12
-
13
  _HF_ARCH_DICT = {
14
  # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
15
  'roberta': {
@@ -120,6 +121,8 @@ class HFTextEncoder(nn.Module):
120
  trust_remote_code: bool = False,
121
  revision: Optional[str] = None,
122
  code_revision: Optional[str] = None,
 
 
123
  model_config_kwargs: Optional[Dict] = None,
124
  ):
125
  super().__init__()
@@ -129,39 +132,35 @@ class HFTextEncoder(nn.Module):
129
  model_config_kwargs = model_config_kwargs or {}
130
 
131
  if config is None:
132
- self.config = AutoConfig.from_pretrained(
133
- model_name_or_path,
134
- trust_remote_code=trust_remote_code,
135
- revision=revision,
136
- code_revision=code_revision,
137
- )
138
- self.config.update(model_config_kwargs)
139
- create_func, model_args = (
140
- (AutoModel.from_pretrained, model_name_or_path)
141
- if pretrained
142
- else (AutoModel.from_config, self.config)
143
- )
144
- if (
145
- hasattr(self.config, 'is_encoder_decoder')
146
- and self.config.is_encoder_decoder
147
- ):
148
- self.transformer = create_func(
149
- model_args,
150
  trust_remote_code=trust_remote_code,
151
  revision=revision,
 
152
  code_revision=code_revision,
153
  **model_config_kwargs,
154
  )
155
- self.transformer = self.transformer.encoder
156
  else:
157
- self.transformer = create_func(
158
- model_args,
 
 
 
 
 
 
159
  trust_remote_code=trust_remote_code,
160
- revision=revision,
161
  add_pooling_layer=False,
162
  code_revision=code_revision,
163
- **model_config_kwargs,
164
  )
 
 
 
 
 
 
165
  else:
166
  self.config = config
167
  self.config.update(model_config_kwargs)
@@ -209,6 +208,50 @@ class HFTextEncoder(nn.Module):
209
  self._task_instructions = self.transformer._task_instructions
210
  self._supports_task_instructions = True
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  @torch.jit.ignore
213
  def set_grad_checkpointing(self, _=True):
214
  self.transformer.gradient_checkpointing_enable()
 
1
  import re
2
+ import warnings
3
  from typing import Dict, Optional
4
+
5
  import torch
6
  import torch.nn as nn
7
  from transformers import AutoConfig, AutoModel, PretrainedConfig
 
11
  BaseModelOutputWithPoolingAndCrossAttentions,
12
  )
13
 
 
14
  _HF_ARCH_DICT = {
15
  # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
16
  'roberta': {
 
121
  trust_remote_code: bool = False,
122
  revision: Optional[str] = None,
123
  code_revision: Optional[str] = None,
124
+ default_instruction_task: Optional[str] = None,
125
+ default_lora_task: Optional[str] = None,
126
  model_config_kwargs: Optional[Dict] = None,
127
  ):
128
  super().__init__()
 
132
  model_config_kwargs = model_config_kwargs or {}
133
 
134
  if config is None:
135
+ if pretrained:
136
+ self.transformer = AutoModel.from_pretrained(
137
+ model_name_or_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  trust_remote_code=trust_remote_code,
139
  revision=revision,
140
+ add_pooling_layer=False,
141
  code_revision=code_revision,
142
  **model_config_kwargs,
143
  )
144
+ self.config = self.transformer.config
145
  else:
146
+ self.config = AutoConfig.from_pretrained(
147
+ model_name_or_path,
148
+ trust_remote_code=trust_remote_code,
149
+ code_revision=code_revision,
150
+ )
151
+ self.config.update(model_config_kwargs)
152
+ self.transformer = AutoModel.from_config(
153
+ self.config,
154
  trust_remote_code=trust_remote_code,
 
155
  add_pooling_layer=False,
156
  code_revision=code_revision,
 
157
  )
158
+ if (
159
+ hasattr(self.config, 'is_encoder_decoder')
160
+ and self.config.is_encoder_decoder
161
+ ):
162
+ self.transformer = self.transformer.encoder
163
+
164
  else:
165
  self.config = config
166
  self.config.update(model_config_kwargs)
 
208
  self._task_instructions = self.transformer._task_instructions
209
  self._supports_task_instructions = True
210
 
211
+ self.default_instruction_task = None
212
+ self.default_lora_task = None
213
+ self.default_instruction = None
214
+ self.default_loraid = None
215
+ if default_instruction_task is not None:
216
+ self.default_instruction_task = default_instruction_task
217
+ self.default_instruction = self.get_instruction_from_task(
218
+ default_instruction_task
219
+ )
220
+ if default_lora_task is not None:
221
+ self.default_lora_task = default_lora_task
222
+ self.default_loraid = self.get_loraid_from_task(default_lora_task)
223
+
224
+ def get_instruction_from_task(self, task: str) -> Optional[str]:
225
+ if self._supports_task_instructions:
226
+ if task not in self._task_instructions:
227
+ raise ValueError(
228
+ f'Unsupported task \'{task}\'. Choose one of the following: '
229
+ f'{", ".join(self._task_instructions)} or set to None to disable '
230
+ f'task instructions completely'
231
+ )
232
+ return self._task_instructions[task]
233
+ else:
234
+ warnings.warn(
235
+ 'Model does not support task instructions, ignoring instruction '
236
+ f"task '{task}'"
237
+ )
238
+ return None
239
+
240
+ def get_loraid_from_task(self, task: str) -> Optional[int]:
241
+ if self._supports_lora:
242
+ if task not in self._lora_adaptation_map:
243
+ raise ValueError(
244
+ f'Unsupported task \'{task}\'. Choose one of the following: '
245
+ f'{", ".join(self._task_instructions)} or set to None to disable '
246
+ f'the LoRA adapters completely'
247
+ )
248
+ return self._lora_adaptation_map[task]
249
+ else:
250
+ warnings.warn(
251
+ f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
252
+ )
253
+ return None
254
+
255
  @torch.jit.ignore
256
  def set_grad_checkpointing(self, _=True):
257
  self.transformer.gradient_checkpointing_enable()
modeling_clip.py CHANGED
@@ -68,6 +68,8 @@ def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
68
  return HFTextEncoder(
69
  model_name_or_path=config.hf_model_name_or_path,
70
  output_dim=config.embed_dim,
 
 
71
  pooler_type=config.pooler_type,
72
  proj_type=config.proj_type,
73
  proj_bias=config.proj_bias,
@@ -532,33 +534,25 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
532
 
533
  truncate_dim = truncate_dim or self.config.truncate_dim
534
 
535
- adapter_mask = None
 
536
  if task:
537
- if not self.text_model._supports_lora:
538
- logger.warning('Text tower does not support LoRA task adaptation')
539
- elif task not in self.text_model._lora_adaptation_map:
540
- raise ValueError(
541
- f'Unsupported task \'{task}\'. Choose one of the following: '
542
- f'{", ".join(self.text_model._lora_adaptation_map)} or bypass the '
543
- '`task` argument to disable LoRA completely.'
544
- )
545
- else:
546
- taskid = self.text_model._lora_adaptation_map[task]
547
- nexamples = 1 if isinstance(sentences, str) else len(sentences)
548
- adapter_mask = torch.full(
549
- (nexamples,), taskid, dtype=torch.int32, device=self.device
550
- )
551
- if not self.text_model._supports_task_instructions:
552
- logger.warning('Text tower does not support task instructions')
553
- elif task not in self.text_model._task_instructions:
554
- raise ValueError(
555
- f'Unsupported task \'{task}\'. Choose one of the following: '
556
- f'{", ".join(self.text_model._task_instructions)} or bypass the '
557
- '`task` argument to disable task instructions completely.'
558
- )
559
- else:
560
- instruction = self.text_model._task_instructions[task]
561
- sentences = [instruction + sentence for sentence in sentences]
562
 
563
  for i in range_iter:
564
  tokens = self.tokenizer(
@@ -566,7 +560,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
566
  return_tensors='pt',
567
  **tokenizer_kwargs,
568
  ).to(self.device)
569
-
570
  embeddings = self.get_text_features(
571
  input_ids=tokens, adapter_mask=adapter_mask
572
  )
 
68
  return HFTextEncoder(
69
  model_name_or_path=config.hf_model_name_or_path,
70
  output_dim=config.embed_dim,
71
+ default_instruction_task=config.default_instruction_task,
72
+ default_lora_task=config.default_lora_task,
73
  pooler_type=config.pooler_type,
74
  proj_type=config.proj_type,
75
  proj_bias=config.proj_bias,
 
534
 
535
  truncate_dim = truncate_dim or self.config.truncate_dim
536
 
537
+ instruction = self.text_model.default_instruction
538
+ loraid = self.text_model.default_loraid
539
  if task:
540
+ _selected_instruction = self.text_model.get_instruction_from_task(task)
541
+ if _selected_instruction is not None:
542
+ instruction = _selected_instruction
543
+ _selected_loraid = self.text_model.get_loraid_from_task(task)
544
+ if _selected_loraid is not None:
545
+ loraid = _selected_loraid
546
+
547
+ if instruction is not None:
548
+ sentences = [instruction + sentence for sentence in sentences]
549
+
550
+ adapter_mask = None
551
+ if loraid is not None:
552
+ nexamples = 1 if isinstance(sentences, str) else len(sentences)
553
+ adapter_mask = torch.full(
554
+ (nexamples,), loraid, dtype=torch.int32, device=self.device
555
+ )
 
 
 
 
 
 
 
 
 
556
 
557
  for i in range_iter:
558
  tokens = self.tokenizer(
 
560
  return_tensors='pt',
561
  **tokenizer_kwargs,
562
  ).to(self.device)
 
563
  embeddings = self.get_text_features(
564
  input_ids=tokens, adapter_mask=adapter_mask
565
  )
rope_embeddings.py CHANGED
@@ -3,7 +3,6 @@
3
  # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
  # --------------------------------------------------------
5
 
6
- import logging
7
  from math import pi
8
 
9
  import torch
@@ -75,10 +74,8 @@ class VisionRotaryEmbedding(nn.Module):
75
 
76
  freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
77
 
78
- self.register_buffer('freqs_cos', freqs.cos())
79
- self.register_buffer('freqs_sin', freqs.sin())
80
-
81
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
82
 
83
  def forward(self, t, start_index=0):
84
  rot_dim = self.freqs_cos.shape[-1]
@@ -137,10 +134,8 @@ class VisionRotaryEmbeddingFast(nn.Module):
137
 
138
  self.patch_dropout = patch_dropout
139
 
140
- self.register_buffer('freqs_cos', freqs_cos)
141
- self.register_buffer('freqs_sin', freqs_sin)
142
-
143
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
144
 
145
  def forward(self, t, patch_indices_keep=None):
146
  if patch_indices_keep is not None:
 
3
  # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
  # --------------------------------------------------------
5
 
 
6
  from math import pi
7
 
8
  import torch
 
74
 
75
  freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
76
 
77
+ self.register_buffer('freqs_cos', freqs.cos(), persistent=False)
78
+ self.register_buffer('freqs_sin', freqs.sin(), persistent=False)
 
 
79
 
80
  def forward(self, t, start_index=0):
81
  rot_dim = self.freqs_cos.shape[-1]
 
134
 
135
  self.patch_dropout = patch_dropout
136
 
137
+ self.register_buffer('freqs_cos', freqs_cos, persistent=False)
138
+ self.register_buffer('freqs_sin', freqs_sin, persistent=False)
 
 
139
 
140
  def forward(self, t, patch_indices_keep=None):
141
  if patch_indices_keep is not None: