hanxiao committed
Commit 79a1cae
1 parent: cd1adcb

fix: encode_image and encode_text


- support bfloat16
- support loading images from URLs, PIL.Image objects, and data:image/ strings
- fix a batching bug and optimize encode_image efficiency

Files changed (1):
  1. modeling_clip.py (+51 -24)
modeling_clip.py CHANGED
@@ -6,7 +6,10 @@
 
 from functools import partial
 from typing import List, Optional, Tuple, Union
-
+from PIL import Image
+from io import BytesIO
+import requests
+import base64
 import numpy as np
 import torch
 import torch.nn.functional as f
@@ -373,7 +376,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         if convert_to_tensor:
             all_embeddings = torch.stack(all_embeddings)
         elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+            all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
 
         if input_was_string:
             all_embeddings = all_embeddings[0]
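The added .to(torch.float32) is what makes bfloat16 checkpoints work here: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises a TypeError. A minimal standalone sketch (hypothetical tensor, not from this file):

import numpy as np
import torch

emb = torch.randn(4, dtype=torch.bfloat16)  # stands in for a bfloat16 model output

# emb.numpy() would raise TypeError: Got unsupported ScalarType BFloat16;
# casting to float32 first gives NumPy a dtype it understands
arr = np.asarray(emb.to(torch.float32).numpy())
print(arr.dtype)  # float32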
@@ -381,10 +384,15 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.train(is_training)
         return all_embeddings
 
+    def decode_data_image(data_image_str):
+        header, data = data_image_str.split(',', 1)
+        image_data = base64.b64decode(data)
+        return Image.open(BytesIO(image_data))
+
     @torch.inference_mode()
     def encode_image(
         self,
-        images: Union[str, List[str]],
+        images: Union[str, List[Union[str, Image.Image]]],
         batch_size: int = 32,
         show_progress_bar: Optional[bool] = None,
         convert_to_numpy: bool = True,
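The new decode_data_image helper handles data URIs, i.e. strings of the form data:image/png;base64,<payload>: it splits off the header at the first comma and base64-decodes the rest. A small round-trip sketch using a tiny in-memory image (all names here are illustrative):

import base64
from io import BytesIO
from PIL import Image

# build a data URI from a 2x2 red PNG held entirely in memory
img = Image.new('RGB', (2, 2), color='red')
buf = BytesIO()
img.save(buf, format='PNG')
data_uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()

# same logic as decode_data_image above: drop the header, decode the payload
header, data = data_uri.split(',', 1)
decoded = Image.open(BytesIO(base64.b64decode(data)))
print(decoded.size)  # (2, 2)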
@@ -394,10 +402,10 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes image embeddings.
-
+
         Args:
-            images(`str` or `List[str]`):
-                image or images paths to be encoded
+            images(`str` or `List[Union[str, Image.Image]]`):
+                image paths, URLs, PIL images, or data:image/ strings to be encoded
             batch_size(`int`, *optional*, defaults to 32):
                 Batch size for the computation
             show_progress_bar(`bool`, *optional*, defaults to None):
@@ -421,35 +429,34 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             If convert_to_tensor, a stacked tensor is returned.
             If convert_to_numpy, a numpy matrix is returned.
         """
-        from PIL import Image
-
+
         is_training = self.training
         self.eval()
-
+
         self.preprocess = self.get_preprocess()
         all_embeddings = []
-
+
         if show_progress_bar is None:
             show_progress_bar = (
                 logger.getEffectiveLevel() == logging.INFO
                 or logger.getEffectiveLevel() == logging.DEBUG
             )
-
+
         if convert_to_tensor:
             convert_to_numpy = False
-
+
         input_was_single_img = False
         if isinstance(images, str) or not hasattr(images, '__len__'):
             images = [images]
             input_was_single_img = True
-
+
         if device is not None:
             self.to(device)
-
-        permutation = np.argsort([-len(i) for i in images])
+
+        permutation = np.argsort([-len(str(i)) for i in images])
         inverse_permutation = np.argsort(permutation)
         images = [images[idx] for idx in permutation]
-
+
         if has_tqdm:
             range_iter = trange(
                 0,
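len(i) becomes len(str(i)) because the input list may now hold PIL.Image.Image objects, which have no __len__; str() gives every input a sortable key, and the argsort/argsort pair restores the original order after batching. A sketch of that round trip (hypothetical inputs):

import numpy as np
from PIL import Image

inputs = ['a.jpg', Image.new('RGB', (1, 1)), 'https://example.com/cat.png']

# len(i) would raise TypeError on the PIL image; len(str(i)) always works
permutation = np.argsort([-len(str(i)) for i in inputs])
inverse_permutation = np.argsort(permutation)

shuffled = [inputs[idx] for idx in permutation]
restored = [shuffled[idx] for idx in inverse_permutation]
assert restored == inputs  # argsort of a permutation inverts it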
@@ -460,26 +467,46 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             )
         else:
             range_iter = range(0, len(images), batch_size)
-
-        for _ in range_iter:
-            processed_inputs = self.preprocess([Image.open(image) for image in images])
+
+        for i in range_iter:
+            batch_images = images[i:i+batch_size]
+            processed_inputs = []
+
+            for img in batch_images:
+                if isinstance(img, str):
+                    if img.startswith('http'):
+                        response = requests.get(img)
+                        image = Image.open(BytesIO(response.content)).convert('RGB')
+                    elif img.startswith('data:image/'):
+                        image = decode_data_image(img).convert('RGB')
+                    else:
+                        image = Image.open(img).convert('RGB')
+                elif isinstance(img, Image.Image):
+                    image = img.convert('RGB')
+                else:
+                    raise ValueError("Unsupported image format")
+
+                processed_inputs.append(self.preprocess(image))
+
+            processed_inputs = torch.stack(processed_inputs).to(device)
             embeddings = self.get_image_features(processed_inputs)
+
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
                 embeddings = embeddings.cpu()
             all_embeddings.extend(embeddings)
-
+
         all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
-
+
         if convert_to_tensor:
             all_embeddings = torch.stack(all_embeddings)
         elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
+            all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
+
         if input_was_single_img:
             all_embeddings = all_embeddings[0]
-
+
         self.train(is_training)
         return all_embeddings
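The rewritten loop is the efficiency fix named in the commit message: the old body ran self.preprocess over the entire images list on every iteration, while the new one slices images[i:i+batch_size] and stacks only that batch. It also accepts URLs, PIL images, and data:image/ strings alongside local paths. A usage sketch (the model id and file names are illustrative, not taken from this commit):

from PIL import Image
from transformers import AutoModel

# any checkpoint shipping this modeling_clip.py works; jina-clip-v1 is a guess
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

embeddings = model.encode_image(
    [
        'photos/cat.jpg',               # local path
        'https://example.com/dog.png',  # URL, fetched with requests
        Image.open('photos/bird.jpg'),  # PIL.Image.Image object
    ],
    batch_size=2,
)
print(embeddings.shape)  # (3, dim) numpy array by default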