gmastrapas committed
Commit d779277 • Parent(s): a4480ad

feat: remove adapter_mask from interface

- hf_model.py +76 -23
- modeling_clip.py +10 -39
hf_model.py
CHANGED
@@ -1,6 +1,6 @@
 import re
 import warnings
-from typing import Dict, Optional
+from typing import Dict, Optional, Union

 import torch
 import torch.nn as nn
@@ -208,21 +208,48 @@ class HFTextEncoder(nn.Module):
             self._task_instructions = self.transformer._task_instructions
             self._supports_task_instructions = True

-        self.
-        self.
-        self.
-        self.
+        self._default_instruction_task = None
+        self._default_lora_task = None
+        self._default_instruction = None
+        self._default_loraid = None
+
         if default_instruction_task is not None:
-            self.
-            self.
+            self._default_instruction_task = default_instruction_task
+            self._default_instruction = self.get_instruction_from_task(
                 default_instruction_task
             )
         if default_lora_task is not None:
-            self.
-            self.
+            self._default_lora_task = default_lora_task
+            self._default_loraid = self.get_loraid_from_task(default_lora_task)
+
+    @property
+    def supports_task_instructions(self) -> bool:
+        return self._supports_task_instructions
+
+    @property
+    def supports_lora(self) -> bool:
+        return self._supports_lora
+
+    @property
+    def task_instructions(self) -> Dict[str, str]:
+        return self._task_instructions
+
+    @property
+    def lora_adaptation_map(self) -> Dict[str, int]:
+        return self._lora_adaptation_map

-
+    @property
+    def default_instruction(self) -> Optional[str]:
+        return self._default_instruction
+
+    @property
+    def default_loraid(self) -> Optional[int]:
+        return self._default_loraid
+
+    def get_instruction_from_task(self, task: Optional[str]) -> Optional[str]:
         if self._supports_task_instructions:
+            if task is None:
+                return self._default_instruction
             if task not in self._task_instructions:
                 raise ValueError(
                     f'Unsupported task \'{task}\'. Choose one of the following: '
@@ -231,14 +258,17 @@ class HFTextEncoder(nn.Module):
                 )
             return self._task_instructions[task]
         else:
-
-
-
-
+            if task is not None:
+                warnings.warn(
+                    'Model does not support task instructions, ignoring instruction '
+                    f"task '{task}'"
+                )
             return None

-    def get_loraid_from_task(self, task: str) -> Optional[int]:
+    def get_loraid_from_task(self, task: Optional[str]) -> Optional[int]:
         if self._supports_lora:
+            if task is None:
+                return self._default_loraid
             if task not in self._lora_adaptation_map:
                 raise ValueError(
                     f'Unsupported task \'{task}\'. Choose one of the following: '
@@ -247,11 +277,18 @@ class HFTextEncoder(nn.Module):
                 )
             return self._lora_adaptation_map[task]
         else:
-
-
-
+            if task is not None:
+                warnings.warn(
+                    f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
+                )
             return None

+    @staticmethod
+    def get_adapter_mask_from_loraid(
+        batch_size: int, loraid: int, device: Union[str, torch.device]
+    ):
+        return torch.full((batch_size,), loraid, dtype=torch.int32, device=device)
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, _=True):
         self.transformer.gradient_checkpointing_enable()
@@ -260,12 +297,28 @@ class HFTextEncoder(nn.Module):
         pass

     def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
-
-
+        if adapter_mask is None:
+            default_loraid = self.default_loraid
+            if default_loraid is not None:
+                adapter_mask = self.get_adapter_mask_from_loraid(
+                    x.shape[0], default_loraid, x.device
+                )
+        else:
+            if not self.supports_lora:
+                warnings.warn(
+                    'Model does not support LoRA adapters, setting adapter_mask to None'
+                )
+                adapter_mask = None
+
+        attention_mask = (x != self.config.pad_token_id).long()
+        lora_kwargs = {}
         if adapter_mask is not None:
-
-
-
+            lora_kwargs['adapter_mask'] = adapter_mask
+
+        out = self.transformer(
+            input_ids=x, attention_mask=attention_mask, **lora_kwargs
+        )
+        pooled_out = self.pooler(out, attention_mask)
         projected = self.proj(pooled_out)
         seqlen = out.last_hidden_state.shape[1]
         tokens = (
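Taken together, the hf_model.py changes move adapter-mask construction out of the caller's hands: forward() now derives the mask from the default LoRA task whenever none is supplied. A minimal sketch of that fallback in isolation, using a made-up token batch and LoRA id instead of a real HFTextEncoder (both hypothetical):

import torch

# Hypothetical stand-ins: a token batch and the id that would come from
# HFTextEncoder._default_loraid after construction with default_lora_task.
x = torch.zeros((4, 16), dtype=torch.long)
default_loraid = 2

adapter_mask = None
if adapter_mask is None and default_loraid is not None:
    # Same construction as HFTextEncoder.get_adapter_mask_from_loraid:
    # one adapter id per example in the batch.
    adapter_mask = torch.full(
        (x.shape[0],), default_loraid, dtype=torch.int32, device=x.device
    )

print(adapter_mask)  # tensor([2, 2, 2, 2], dtype=torch.int32)

The mask is only forwarded to the transformer as adapter_mask when it is not None, so models without LoRA support keep the plain input_ids/attention_mask call.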
modeling_clip.py
CHANGED
@@ -159,9 +159,6 @@ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
         self,
         input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
         return_dict: Optional[bool] = None,
-        use_lora: bool = False,
-        adapter_mask: Optional[torch.Tensor] = None,
-        task: Optional[str] = None,
         *_,
         **__,
     ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
@@ -169,12 +166,7 @@ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
             return_dict if return_dict is not None else self.config.use_return_dict
         )
         x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
-        feats = self.text_model(
-            x=x,
-            use_lora=use_lora,
-            adapter_mask=adapter_mask,
-            task=task,
-        )
+        feats = self.text_model(x=x)
         out = CLIPTextModelOutput(text_embeds=feats)
         return out if return_dict else out.to_tuple()

@@ -277,12 +269,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
     def get_text_features(
         self,
         input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
-        adapter_mask: Optional[torch.Tensor] = None,
         *_,
         **__,
     ) -> torch.FloatTensor:
         x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
-        return self.text_projection(self.text_model(x=x
+        return self.text_projection(self.text_model(x=x))

     def get_image_features(
         self,
@@ -461,9 +452,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             sentences(`str` or `List[str]`):
                 Sentence or sentences to be encoded
             task(`str`, *optional*, defaults to `None`):
-                Specifies the task for which the encoding is intended. If `task` is
-
-
+                Specifies the task for which the encoding is intended. If a `task` is
+                provided, a task-specific instruction is added to the beginning of each
+                sentence. If `task` is not provided, no instructions are added.
             batch_size(`int`, *optional*, defaults to 32):
                 Batch size for the computation
             show_progress_bar(`bool`, *optional*, defaults to None):
@@ -534,35 +525,17 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):

         truncate_dim = truncate_dim or self.config.truncate_dim

-        instruction = self.text_model.
-
-        if task:
-            _selected_instruction = self.text_model.get_instruction_from_task(task)
-            if _selected_instruction is not None:
-                instruction = _selected_instruction
-            _selected_loraid = self.text_model.get_loraid_from_task(task)
-            if _selected_loraid is not None:
-                loraid = _selected_loraid
-
-        if instruction is not None:
+        instruction = self.text_model.get_instruction_from_task(task)
+        if instruction:
             sentences = [instruction + sentence for sentence in sentences]

-        adapter_mask = None
-        if loraid is not None:
-            nexamples = 1 if isinstance(sentences, str) else len(sentences)
-            adapter_mask = torch.full(
-                (nexamples,), loraid, dtype=torch.int32, device=self.device
-            )
-
         for i in range_iter:
             tokens = self.tokenizer(
                 sentences[i: i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            embeddings = self.get_text_features(
-                input_ids=tokens, adapter_mask=adapter_mask
-            )
+            embeddings = self.get_text_features(input_ids=tokens)
             if truncate_dim:
                 embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
@@ -589,7 +562,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self,
         input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
         pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
-        adapter_mask: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         return_loss: Optional[bool] = None,
         *_,
@@ -599,9 +571,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             return_dict if return_dict is not None else self.config.use_return_dict
         )
         image_embeds = self.get_image_features(pixel_values=pixel_values)
-        text_embeds = self.get_text_features(
-
-        )
+        text_embeds = self.get_text_features(input_ids=input_ids)
+
         # normalized features
         image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
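On the modeling_clip.py side, the public interface now only exposes `task`; `adapter_mask` and `use_lora` are gone and LoRA selection happens inside the text tower. A rough usage sketch under that assumption — the repository id below is a placeholder for illustration, not something named in this diff:

import torch
from transformers import AutoModel, AutoTokenizer

repo = 'org/jina-clip-checkpoint'  # placeholder repo id, not from this commit
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

# Before this commit, callers could pass adapter_mask / use_lora here;
# now get_text_features only needs the tokenized input.
tokens = tokenizer(['a photo of a cat'], return_tensors='pt')
with torch.no_grad():
    text_embeds = model.get_text_features(input_ids=tokens)

The higher-level encode path shown in the last hunks behaves the same way: it resolves the instruction via get_instruction_from_task(task), prefixes it to the sentences, and calls get_text_features without building an adapter mask.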