from typing import Dict, List, Optional, Tuple, Union, Iterable

import numpy as np
import torch
import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    ChannelDimension,
    get_resize_output_image_size,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    get_channel_dimension_axis,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from transformers.utils import is_torch_tensor


class FaceSegformerImageProcessor(BaseImageProcessor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.image_size = kwargs.get("image_size", (224, 224))
        self.normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
        self.normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])
        self.resample = kwargs.get("resample", PILImageResampling.BILINEAR)
        self.data_format = kwargs.get("data_format", ChannelDimension.FIRST)

    @staticmethod
    def normalize(
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        max_pixel_value: float = 255.0,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Copied from:
        https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/image_transforms.py#L209
        BUT uses the formula from albumentations:
        https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize

            img = (img - mean * max_pixel_value) / (std * max_pixel_value)
        """
        if not isinstance(image, np.ndarray):
            raise ValueError("image must be a numpy array")

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)
        channel_axis = get_channel_dimension_axis(
            image, input_data_format=input_data_format
        )
        num_channels = image.shape[channel_axis]

        # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
        # We preserve the original dtype if it is a float type to prevent upcasting float16.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)

        if isinstance(mean, Iterable):
            if len(mean) != num_channels:
                raise ValueError(
                    f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}"
                )
        else:
            mean = [mean] * num_channels
        mean = np.array(mean, dtype=image.dtype)

        if isinstance(std, Iterable):
            if len(std) != num_channels:
                raise ValueError(
                    f"std must have {num_channels} elements if it is an iterable, got {len(std)}"
                )
        else:
            std = [std] * num_channels
        std = np.array(std, dtype=image.dtype)

        # Uses max_pixel_value for normalization
        if input_data_format == ChannelDimension.LAST:
            image = (image - mean * max_pixel_value) / (std * max_pixel_value)
        else:
            image = ((image.T - mean * max_pixel_value) / (std * max_pixel_value)).T

        image = (
            to_channel_dimension_format(image, data_format, input_data_format)
            if data_format is not None
            else image
        )
        return image

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Copied from:
        https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain either 'shortest_edge' or 'height' and 'width'."
            )
        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def __call__(self, images: ImageInput, masks: ImageInput = None, **kwargs):
        """
        Adapted from:
        https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py
        """
        # single to iterable if needed
        images = make_list_of_images(images)

        # validate
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # make numpy arrays
        images = [to_numpy_array(image) for image in images]

        # get channel dimensions
        input_data_format = kwargs.get("input_data_format")
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # check if training
        # todo: can also assume if masks are passed that we are doing training?
        if kwargs.get("do_training", False) is True:
            if masks is None:
                raise ValueError("must pass masks if doing training.")
            # todo: implement this soon.
            raise NotImplementedError("not yet implemented.")
            # Assume we want to do all transformations for training
        else:
            # do transformations for inference...
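            # (Assumed intent, inferred from the calls below.) Inference-time
            # preprocessing is: 1) resize each image to ``self.image_size``,
            # 2) normalize with the albumentations-style formula in ``normalize``,
            # 3) convert to the requested channel-dimension format before batching.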
            images = [
                self.resize(
                    image=image,
                    size={"height": self.image_size[0], "width": self.image_size[1]},
                    resample=kwargs.get("resample") or self.resample,
                    input_data_format=input_data_format,
                )
                for image in images
            ]
            images = [
                self.normalize(
                    image=image,
                    mean=kwargs.get("normalize_mean") or self.normalize_mean,
                    std=kwargs.get("normalize_std") or self.normalize_std,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        # fix dimensions
        images = [
            to_channel_dimension_format(
                image,
                kwargs.get("data_format") or self.data_format,
                input_channel_dim=input_data_format,
            )
            for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type="pt")

    # Copied from transformers.models.segformer.image_processing_segformer.SegformerImageProcessor.post_process_semantic_segmentation
    def post_process_semantic_segmentation(
        self, outputs, target_sizes: List[Tuple] = None
    ):
        """
        Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
        PyTorch.

        Args:
            outputs ([`SegformerForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction.
                If unset, predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes`
            is specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []
            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0),
                    size=target_sizes[idx],
                    mode="bilinear",
                    align_corners=False,
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [
                semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])
            ]

        return semantic_segmentation
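

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the processor). The checkpoint
# path and image file below are placeholders/assumptions; substitute a
# Segformer checkpoint fine-tuned for face segmentation. It shows the intended
# round trip: preprocess with FaceSegformerImageProcessor, run the model, then
# map the logits back to the original resolution with
# post_process_semantic_segmentation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import SegformerForSemanticSegmentation

    checkpoint = "path/to/face-segformer-checkpoint"  # placeholder path (assumption)
    image = Image.open("face.jpg").convert("RGB")  # placeholder example image

    processor = FaceSegformerImageProcessor()
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
    model.eval()

    # __call__ resizes, normalizes and batches; it returns a BatchFeature whose
    # "pixel_values" is a torch tensor of shape (batch, channels, height, width).
    inputs = processor(image)

    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"])

    # PIL's .size is (width, height); target_sizes expects (height, width).
    segmentation = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    print(segmentation.shape)  # (height, width) map of per-pixel class ids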