bwang0911 committed
Commit b3163bd
1 Parent(s): a91c3ec

refactor: refine load images

Files changed (1)
  1. modeling_clip.py +76 -13
modeling_clip.py CHANGED
@@ -223,6 +223,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.text_projection = nn.Identity()
 
         self.tokenizer = None
+        self.preprocess = None
         self.post_init()
 
     def get_text_features(
@@ -249,7 +250,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
 
     def get_tokenizer(self):
         if not self.tokenizer:
-            self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path, trust_remote_code=True)
         return self.tokenizer
 
     @torch.inference_mode()
@@ -264,7 +265,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         **tokenizer_kwargs,
-    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]::
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
         Args:
@@ -373,19 +374,81 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.train(is_training)
         return all_embeddings
 
+    def get_preprocess(self):
+        if not self.preprocess:
+            self.preprocess = AutoImageProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
+        return self.preprocess
+
+    @torch.inference_mode()
     def encode_image(
         self,
-        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
-        return_dict: Optional[bool] = None,
-        *_,
-        **__,
-    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        feats = self.get_image_features(pixel_values=pixel_values)
-        out = CLIPVisionModelOutput(image_embeds=feats)
-        return out if return_dict else out.to_tuple()
+        images: Union[str, List[str]],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: Optional[torch.device] = None,
+        normalize_embeddings: bool = False,
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+        from PIL import Image
+
+        is_training = self.training
+        self.eval()
+
+        self.preprocess = self.get_preprocess()
+
+        if show_progress_bar is None:
+            show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO
+                or logger.getEffectiveLevel() == logging.DEBUG
+            )
+
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        input_was_single_img = False
+        if isinstance(images, str) or not hasattr(images, '__len__'):
+            images = [images]
+            input_was_single_img = True
+
+        if device is not None:
+            self.to(device)
+
+        permutation = np.argsort([-len(i) for i in images])
+        inverse_permutation = np.argsort(permutation)
+        images = [images[idx] for idx in permutation]
+
+        if has_tqdm:
+            range_iter = trange(
+                0,
+                len(images),
+                batch_size,
+                desc="Encoding",
+                disable=not show_progress_bar,
+            )
+        else:
+            range_iter = range(0, len(images), batch_size)
+
+        all_embeddings = []
+        for i in range_iter:
+            processed_inputs = self.preprocess([Image.open(image) for image in images[i : i + batch_size]])
+            embeddings = self.get_image_features(processed_inputs)
+            if convert_to_numpy:
+                embeddings = embeddings.cpu()
+            all_embeddings.extend(embeddings)
+
+        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+
+        if convert_to_tensor:
+            all_embeddings = torch.stack(all_embeddings)
+        elif convert_to_numpy:
+            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+        if input_was_single_img:
+            all_embeddings = all_embeddings[0]
+
+        self.train(is_training)
+        return all_embeddings
 
     def forward(
         self,
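The refactored encode_image replaces the tensor-level CLIPVisionModelOutput path with a file-path API that mirrors the sentence encode method: it lazily loads an image processor via get_preprocess(), batches the inputs, and returns numpy arrays or tensors. A minimal usage sketch follows, assuming the model is loaded from the Hub with trust_remote_code=True; the repo id 'jinaai/jina-clip-v1' and the image file names are placeholders, not taken from this commit.

from transformers import AutoModel

# Load the custom JinaCLIPModel implementation from the Hub
# ('jinaai/jina-clip-v1' is an assumed repo id).
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

# Embed images directly from file paths; preprocessing is resolved lazily
# through get_preprocess(), so no separate image processor is constructed here.
embeddings = model.encode_image(['cat.png', 'dog.png'], convert_to_numpy=True)
print(embeddings.shape)  # (2, embedding_dim)

Because the tokenizer and the image processor are both resolved from the model config on first use, text and image embeddings can be produced from the same loaded model object without the caller preparing any processors.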