fix: encode_image and encode_text
- support bfloat16
- support loading images from URLs, PIL.Image objects, and data:image/ base64 strings
- fix a bug and optimize encode_image efficiency
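For reference, a minimal usage sketch of the updated encode_image interface. This is a sketch only: the model id, URL, and file path are placeholders, and it assumes the checkpoint is loaded through transformers with trust_remote_code so that this modeling_clip.py is used; bfloat16 loading follows the first bullet.

import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoModel

# Placeholder checkpoint id; any checkpoint shipping this modeling_clip.py should behave the same.
model = AutoModel.from_pretrained(
    'jinaai/jina-clip-v1', torch_dtype=torch.bfloat16, trust_remote_code=True
)

# Build a data:image/ string to exercise the new decode_data_image path.
buffer = BytesIO()
Image.new('RGB', (32, 32), 'red').save(buffer, format='PNG')
data_uri = 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode()

images = [
    'https://example.com/cat.jpg',  # URL, fetched with requests (placeholder)
    'photos/dog.png',               # local path, opened with PIL (placeholder)
    Image.new('RGB', (224, 224)),   # PIL.Image, used directly
    data_uri,                       # data:image/ string, decoded from base64
]

image_embeddings = model.encode_image(images)            # numpy float32 matrix by default
text_embeddings = model.encode_text(['a photo of a cat'])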
modeling_clip.py +51 -24
modeling_clip.py CHANGED
@@ -6,7 +6,10 @@
 
 from functools import partial
 from typing import List, Optional, Tuple, Union
-
+from PIL import Image
+from io import BytesIO
+import requests
+import base64
 import numpy as np
 import torch
 import torch.nn.functional as f
@@ -373,7 +376,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         if convert_to_tensor:
             all_embeddings = torch.stack(all_embeddings)
         elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+            all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
 
         if input_was_string:
             all_embeddings = all_embeddings[0]
@@ -381,10 +384,15 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.train(is_training)
         return all_embeddings
 
+    def decode_data_image(data_image_str):
+        header, data = data_image_str.split(',', 1)
+        image_data = base64.b64decode(data)
+        return Image.open(BytesIO(image_data))
+
     @torch.inference_mode()
     def encode_image(
         self,
-        images: Union[str, List[str]],
+        images: Union[str, List[Union[str, Image.Image]]],
         batch_size: int = 32,
         show_progress_bar: Optional[bool] = None,
         convert_to_numpy: bool = True,
@@ -394,10 +402,10 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes image embeddings.
-
+
         Args:
-            images(`str` or `List[str]`):
-                image
+            images(`str` or `List[Union[str, Image.Image]]`):
+                image paths, URLs, PIL images, or data:image/ strings to be encoded
             batch_size(`int`, *optional*, defaults to 32):
                 Batch size for the computation
             show_progress_bar(`bool`, *optional*, defaults to None):
@@ -421,35 +429,34 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         If convert_to_tensor, a stacked tensor is returned.
         If convert_to_numpy, a numpy matrix is returned.
         """
-
-
+
         is_training = self.training
         self.eval()
-
+
         self.preprocess = self.get_preprocess()
         all_embeddings = []
-
+
         if show_progress_bar is None:
             show_progress_bar = (
                 logger.getEffectiveLevel() == logging.INFO
                 or logger.getEffectiveLevel() == logging.DEBUG
             )
-
+
         if convert_to_tensor:
             convert_to_numpy = False
-
+
         input_was_single_img = False
         if isinstance(images, str) or not hasattr(images, '__len__'):
            images = [images]
            input_was_single_img = True
-
+
         if device is not None:
             self.to(device)
-
-        permutation = np.argsort([-len(i) for i in images])
+
+        permutation = np.argsort([-len(str(i)) for i in images])
         inverse_permutation = np.argsort(permutation)
         images = [images[idx] for idx in permutation]
-
+
         if has_tqdm:
             range_iter = trange(
                 0,
@@ -460,26 +467,46 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             )
         else:
             range_iter = range(0, len(images), batch_size)
-
-        for i in range_iter:
-            processed_inputs = self.preprocess(images[i:i+batch_size])
+
+        for i in range_iter:
+            batch_images = images[i:i+batch_size]
+            processed_inputs = []
+
+            for img in batch_images:
+                if isinstance(img, str):
+                    if img.startswith('http'):
+                        response = requests.get(img)
+                        image = Image.open(BytesIO(response.content)).convert('RGB')
+                    elif img.startswith('data:image/'):
+                        image = decode_data_image(img).convert('RGB')
+                    else:
+                        image = Image.open(img).convert('RGB')
+                elif isinstance(img, Image.Image):
+                    image = img.convert('RGB')
+                else:
+                    raise ValueError("Unsupported image format")
+
+                processed_inputs.append(self.preprocess(image))
+
+            processed_inputs = torch.stack(processed_inputs).to(device)
             embeddings = self.get_image_features(processed_inputs)
+
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
                 embeddings = embeddings.cpu()
             all_embeddings.extend(embeddings)
-
+
         all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
-
+
         if convert_to_tensor:
             all_embeddings = torch.stack(all_embeddings)
         elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
+            all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
+
         if input_was_single_img:
             all_embeddings = all_embeddings[0]
-
+
         self.train(is_training)
         return all_embeddings
 
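A note on the bfloat16 part of the change: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises a TypeError; the upcast to float32 added in both encode_text and encode_image avoids that. A standalone sketch of the behavior:

import torch

emb = torch.randn(4, dtype=torch.bfloat16)

try:
    emb.numpy()                         # raises: NumPy has no bfloat16 dtype
except TypeError as err:
    print(err)

arr = emb.to(torch.float32).numpy()     # upcast first, as the patched lines do
print(arr.dtype)                        # float32

Separately, changing the sort key from -len(i) to -len(str(i)) is what lets PIL.Image inputs pass through the length-based permutation: len() is not defined on a PIL image, while len(str(i)) works for any input type.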