# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import requests
from packaging import version

from .utils import (
    ExplicitEnum,
    is_jax_tensor,
    is_tf_tensor,
    is_torch_available,
    is_torch_tensor,
    is_vision_available,
    logging,
    requires_backends,
    to_numpy,
)
from .utils.constants import (  # noqa: F401
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
)


if is_vision_available():
    import PIL.Image
    import PIL.ImageOps

    # Pillow 9.1.0 moved the resampling filters into the `PIL.Image.Resampling` enum; older releases only
    # expose them as module-level constants on `PIL.Image`.
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PILImageResampling = PIL.Image.Resampling
    else:
        PILImageResampling = PIL.Image

if TYPE_CHECKING:
    if is_torch_available():
        import torch


logger = logging.get_logger(__name__)


ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
]  # noqa


class ChannelDimension(ExplicitEnum):
    """Position of the channel axis in an image array."""

    FIRST = "channels_first"
    LAST = "channels_last"


class AnnotationFormat(ExplicitEnum):
    """Annotation formats accepted by `validate_annotations`."""

    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"


class AnnotionFormat(ExplicitEnum):
    # Misspelled predecessor of `AnnotationFormat`, kept for backward compatibility. `validate_annotations`
    # emits a deprecation warning and promotes it to `AnnotationFormat`.
    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value


AnnotationType = Dict[str, Union[int, str, List[Dict]]]


def is_pil_image(img):
    """
    Returns whether `img` is a `PIL.Image.Image`. Always falsy when the vision backend is not available.
    """
    return is_vision_available() and isinstance(img, PIL.Image.Image)


def is_valid_image(img):
    """
    Returns whether `img` is a single supported image object: a PIL image, a NumPy array, or a torch / TensorFlow /
    JAX tensor.
    """
    return (
        is_pil_image(img)
        or isinstance(img, np.ndarray)
        or is_torch_tensor(img)
        or is_tf_tensor(img)
        or is_jax_tensor(img)
    )


def valid_images(imgs):
    """
    Returns whether `imgs` is a valid image, or an (arbitrarily nested) list/tuple containing only valid images.
    """
    # If we have a list or tuple of images, make sure every image is valid
    if isinstance(imgs, (list, tuple)):
        return all(valid_images(img) for img in imgs)
    # If not a list or tuple, we have been given a single image or batched tensor of images
    return bool(is_valid_image(imgs))


def is_batched(img):
    """
    Returns whether `img` is a non-empty list or tuple whose first element is a valid image.
    """
    if isinstance(img, (list, tuple)):
        # An empty sequence is not a batch; guard the `img[0]` lookup so it cannot raise `IndexError`.
        return len(img) > 0 and bool(is_valid_image(img[0]))
    return False


def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    return np.min(image) >= 0 and np.max(image) <= 1


def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
    If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image or batch of images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.

    Raises:
        `ValueError`: If `images` is not a supported image type, or if an array/tensor input has neither
        `expected_ndims` nor `expected_ndims + 1` dimensions.
    """
    if is_batched(images):
        return images

    # Either the input is a single image, in which case we create a list of length 1
    if is_pil_image(images):
        # PIL images are never batched. `is_pil_image` (rather than a bare `isinstance`) keeps this function usable
        # for array inputs when Pillow is not installed.
        return [images]

    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            # Batch of images
            images = list(images)
        elif images.ndim == expected_ndims:
            # Single image
            images = [images]
        else:
            raise ValueError(
                f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
                f" {images.ndim} dimensions."
            )
        return images
    raise ValueError(
        "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
        f"jax.ndarray, but got {type(images)}."
    )


def to_numpy_array(img) -> np.ndarray:
    """
    Converts a valid image (see `is_valid_image`) to a NumPy array.

    Raises:
        `ValueError`: If `img` is not a supported image type.
    """
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    if is_pil_image(img):
        return np.array(img)
    return to_numpy(img)


def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.

    Raises:
        `ValueError`: If `image` is neither 3- nor 4-dimensional, or if neither candidate axis has a size listed in
        `num_channels`.
    """
    num_channels = num_channels if num_channels is not None else (1, 3)
    num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        # Axis 0 is skipped for 4D inputs, so the candidate channel axes are 1 and 3.
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    # The first axis is checked first, so it wins when both axes match (e.g. a 3x224x3 array).
    if image.shape[first_dim] in num_channels:
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in num_channels:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")


def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.

    Raises:
        `ValueError`: If `input_data_format` is neither channels-first nor channels-last.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.FIRST:
        return image.ndim - 3
    elif input_data_format == ChannelDimension.LAST:
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {input_data_format}")


def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.

    Raises:
        `ValueError`: If `channel_dim` is neither channels-first nor channels-last.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    if channel_dim == ChannelDimension.FIRST:
        return image.shape[-2], image.shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        return image.shape[-3], image.shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")


def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    Returns whether `annotation` is a dict with an `image_id` key and an `annotations` list/tuple that is either
    empty or starts with a dict. Only the first entry of `annotations` is inspected.
    """
    return (
        isinstance(annotation, dict)
        and "image_id" in annotation
        and "annotations" in annotation
        and isinstance(annotation["annotations"], (list, tuple))
        and (
            # an image can have no annotations
            len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
        )
    )


def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    Returns whether `annotation` is a dict with `image_id` and `file_name` keys and a `segments_info` list/tuple that
    is either empty or starts with a dict. Only the first entry of `segments_info` is inspected.
    """
    return (
        isinstance(annotation, dict)
        and "image_id" in annotation
        and "segments_info" in annotation
        and "file_name" in annotation
        and isinstance(annotation["segments_info"], (list, tuple))
        and (
            # an image can have no segments
            len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
        )
    )


def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """Returns whether every element of `annotations` passes `is_valid_annotation_coco_detection`."""
    return all(is_valid_annotation_coco_detection(ann) for ann in annotations)


def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """Returns whether every element of `annotations` passes `is_valid_annotation_coco_panoptic`."""
    return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)


def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string may be an `http(s)://` URL, a path to a local
            file, or a base64 encoded image (optionally prefixed with a `data:image/...,` header).
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image, EXIF-transposed and converted to RGB.

    Raises:
        `ValueError`: If `image` is a string that is neither a URL, an existing file, nor decodable base64 image
        data, or if `image` is neither a string nor a PIL image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            # Try to load as base64
            try:
                if image.startswith("data:image/"):
                    # Strip the data-URI header. This is inside the `try` so that a header without a payload
                    # surfaces as the documented `ValueError` instead of a bare `IndexError`.
                    image = image.split(",")[1]
                b64 = base64.b64decode(image, validate=True)
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                ) from e
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.

    Each `do_*` flag that is truthy requires its companion argument(s) to be non-`None`:
    `do_rescale` -> `rescale_factor`; `do_pad` -> `size_divisibility`; `do_normalize` -> `image_mean` and
    `image_std`; `do_center_crop` -> `crop_size`; `do_resize` -> `size` and `resample`.
    """
    if do_rescale and rescale_factor is None:
        raise ValueError("rescale_factor must be specified if do_rescale is True.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
            "Depending on model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
        )

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

    if do_center_crop and crop_size is None:
        raise ValueError("crop_size must be specified if do_center_crop is True.")

    if do_resize and (size is None or resample is None):
        raise ValueError("size and resample must be specified if do_resize is True.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
    """
    Mixin that contains utilities for preparing image features.
    """

    def _ensure_format_supported(self, image):
        """Raises `ValueError` unless `image` is a PIL image, a NumPy array or a torch tensor."""
        if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
            raise ValueError(
                f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
                "`torch.Tensor` are."
            )

    def to_pil_image(self, image, rescale=None):
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        self._ensure_format_supported(image)

        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale defaults to the array being of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel has been moved to first dim, we put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        # Already a PIL image: returned unchanged.
        return image

    def convert_rgb(self, image):
        """
        Converts `PIL.Image.Image` to RGB format. Arrays and tensors are returned unchanged.

        Args:
            image (`PIL.Image.Image`):
                The image to convert.
        """
        self._ensure_format_supported(image)
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
        """
        Rescale a numpy image by scale amount, i.e. returns `image * scale`.
        """
        self._ensure_format_supported(image)
        return image * scale

    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        if is_torch_tensor(image):
            image = image.numpy()

        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        # Note: the transpose is applied to any 3D input, i.e. it assumes the input is channels-last.
        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image

    def expand_dims(self, image):
        """
        Expands 2-dimensional `image` to 3 dimensions by inserting a new leading axis. PIL images are returned
        unchanged.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand.
        """
        self._ensure_format_supported(image)

        # Do nothing if PIL image
        if isinstance(image, PIL.Image.Image):
            return image

        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            image = np.expand_dims(image, axis=0)
        return image

    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
        if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If the input image is a PIL image, it automatically gets rescaled. If it's another
        # type it may need rescaling.
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Bring `mean` and `std` to the same array type as `image` so the arithmetic below broadcasts.
        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                std = torch.tensor(std)

        if image.ndim == 3 and image.shape[0] in [1, 3]:
            # Channels-first image: reshape the per-channel statistics to (C, 1, 1) for broadcasting.
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std

    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
                this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to use for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`,`size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.

        Raises:
            `ValueError`: If `max_size` is given and is not strictly greater than the requested smaller-edge size.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # specified size only for the smallest edge
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                if short == requested_new_short:
                    return image

                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    if new_long > max_size:
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)

    def center_crop(self, image, size):
        """
        Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
        size given, it will be padded (so the returned result has the size asked).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to which crop the image.

        Returns:
            new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
            height, width).
        """
        self._ensure_format_supported(image)

        if not isinstance(size, tuple):
            size = (size, size)

        # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
        if is_torch_tensor(image) or isinstance(image, np.ndarray):
            if image.ndim == 2:
                image = self.expand_dims(image)
            image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
        else:
            image_shape = (image.size[1], image.size[0])

        top = (image_shape[0] - size[0]) // 2
        bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
        left = (image_shape[1] - size[1]) // 2
        right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

        # For PIL Images we have a method to crop directly.
        if isinstance(image, PIL.Image.Image):
            return image.crop((left, top, right, bottom))

        # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
        channel_first = image.shape[0] in [1, 3]

        # Transpose (height, width, n_channels) format images
        if not channel_first:
            if isinstance(image, np.ndarray):
                image = image.transpose(2, 0, 1)
            if is_torch_tensor(image):
                image = image.permute(2, 0, 1)

        # Check if cropped area is within image boundaries
        if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
            return image[..., top:bottom, left:right]

        # Otherwise, we may need to pad if the image is too small. Oh joy...
        new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
        if isinstance(image, np.ndarray):
            new_image = np.zeros_like(image, shape=new_shape)
        elif is_torch_tensor(image):
            new_image = image.new_zeros(new_shape)

        # Paste the original image in the middle of the zero-filled canvas.
        top_pad = (new_shape[-2] - image_shape[0]) // 2
        bottom_pad = top_pad + image_shape[0]
        left_pad = (new_shape[-1] - image_shape[1]) // 2
        right_pad = left_pad + image_shape[1]
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

        # Shift the crop window into the canvas' coordinate system before slicing.
        top += top_pad
        bottom += top_pad
        left += left_pad
        right += left_pad

        new_image = new_image[
            ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
        ]

        return new_image

    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
        `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
                be first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        # Reverse along the first (channel) axis.
        return image[::-1, :, :]

    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
        counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
                rotating.
            angle:
                Rotation angle, forwarded to `PIL.Image.Image.rotate`.
            resample (*optional*, defaults to `PILImageResampling.NEAREST`):
                The resampling filter, forwarded to `PIL.Image.Image.rotate`.
            expand, center, translate, fillcolor:
                Forwarded unchanged to `PIL.Image.Image.rotate`.

        Returns:
            image: A rotated `PIL.Image.Image`.
        """
        # Use the version-aware alias (as `resize` does) rather than the bare `PIL.Image.NEAREST` constant.
        resample = resample if resample is not None else PILImageResampling.NEAREST

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )


def promote_annotation_format(annotation_format: Union[AnnotionFormat, AnnotationFormat]) -> AnnotationFormat:
    """Maps a (possibly deprecated `AnnotionFormat`) member to the `AnnotationFormat` member with the same value."""
    # can be removed when `AnnotionFormat` is fully deprecated
    return AnnotationFormat(annotation_format.value)


def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: Tuple[AnnotationFormat, ...],
    annotations: List[Dict],
) -> None:
    """
    Checks that `annotation_format` is supported and that `annotations` are well-formed for that format.

    Args:
        annotation_format (`AnnotationFormat`):
            The format of `annotations`. The deprecated `AnnotionFormat` is accepted with a warning.
        supported_annotation_formats (`Tuple[AnnotationFormat, ...]`):
            The formats the caller is able to handle.
        annotations (`List[Dict]`):
            The annotations to validate.

    Raises:
        `ValueError`: If the format is not in `supported_annotation_formats`, or if `annotations` fail the COCO
        detection / panoptic structural checks for the given format.
    """
    if isinstance(annotation_format, AnnotionFormat):
        logger.warning_once(
            f"`{annotation_format.__class__.__name__}` is deprecated and will be removed in v4.38. "
            f"Please use `{AnnotationFormat.__name__}` instead."
        )
        annotation_format = promote_annotation_format(annotation_format)

    if annotation_format not in supported_annotation_formats:
        # Report the offending value (the original interpolated the builtin `format` by mistake).
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )


def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    """
    Logs a warning listing every key in `captured_kwargs` that is not in `valid_processor_keys`.

    Args:
        valid_processor_keys (`List[str]`):
            The keyword argument names the processor recognizes.
        captured_kwargs (`List[str]`):
            The keyword argument names that were actually passed.
    """
    unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
    if unused_keys:
        # Sort so the message is deterministic (set iteration order is not).
        unused_key_str = ", ".join(sorted(unused_keys))
        # TODO raise a warning here instead of simply logging?
        logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")