from pathlib import Path

import PIL.Image
import torch
import torchvision.transforms.functional as FT
from torch.utils.data import Dataset
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, Resize, ToTensor

PROJECT_ROOT = Path(__file__).absolute().parents[1]


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def collate_fn(batch):
    """
    Discard None images in a batch when using a torch DataLoader.
    :param batch: input batch
    :return: output batch without the None values
    """
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)
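
# Usage sketch: `ImageDataset` is a hypothetical Dataset subclass whose
# __getitem__ returns None for images that fail to load.
#   loader = torch.utils.data.DataLoader(ImageDataset(...), batch_size=32,
#                                        collate_fn=collate_fn)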

class TargetPad:
    """
    If an image's aspect ratio exceeds a target ratio, pad the image so that it matches the target ratio.
    For more details see Baldrati et al., "Effective Conditioned and Composed Image Retrieval Combining CLIP-Based Features", Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: target aspect ratio
        :param size: preprocessing output dimension
        """
        self.size = size
        self.target_ratio = target_ratio

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        w, h = image.size
        actual_ratio = max(w, h) / min(w, h)
        if actual_ratio < self.target_ratio:  # ratio already below the target: no padding needed
            return image
        scaled_max_wh = max(w, h) / self.target_ratio  # shorter-side length that brings the ratio down to the target
        hp = max(int((scaled_max_wh - w) / 2), 0)  # horizontal padding per side
        vp = max(int((scaled_max_wh - h) / 2), 0)  # vertical padding per side
        padding = [hp, vp, hp, vp]  # left, top, right, bottom
        return FT.pad(image, padding, 0, 'constant')
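
# Worked example (a sketch): with TargetPad(target_ratio=1.25, size=224), a
# 400x100 image has aspect ratio 4.0 >= 1.25, so scaled_max_wh = 400 / 1.25 = 320,
# hp = 0 and vp = 110; the padded image is 400x320, whose ratio is exactly 1.25.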

def targetpad_transform(target_ratio: float, dim: int) -> Compose:
    """
    CLIP-like preprocessing transform that applies TargetPad padding before resizing.
    :param target_ratio: target aspect ratio for TargetPad
    :param dim: output image dimension
    :return: CLIP-like torchvision Compose transform
    """
    return Compose([
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP mean/std
    ])
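
# Usage sketch (the 1.25 ratio and 224-px size follow common CLIP ViT-B
# settings and are assumptions here, as is the placeholder path):
#   preprocess = targetpad_transform(1.25, 224)
#   image = PIL.Image.open('photo.jpg')
#   tensor = preprocess(image)  # torch.Tensor of shape [3, 224, 224]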