import functools
import json
import random

# NOTE: keep `spaces` ahead of the torch imports; ZeroGPU expects to be imported
# before CUDA gets initialized.
import spaces
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import torch
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageColor
from torchvision.io import ImageReadMode, read_image
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

# Load pre-trained model transformations.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

# 0-based category IDs that are absent from the facere+ label space: facere+ class
# indices skip over these IDs, so its predictions are re-mapped before lookup.
_FACERE_PLUS_EXCLUDED_IDS = frozenset({2, 12, 16, 19, 20})
# facere+ class index (0-based) -> original 0-based category ID. Built once and
# sorted explicitly so the mapping does not depend on set iteration order.
_FACERE_PLUS_TO_ORIGINAL_ID = dict(
    enumerate(sorted(set(range(27)) - _FACERE_PLUS_EXCLUDED_IDS))
)


def fix_category_id(cat_ids: list) -> list:
    """
    Map 1-based facere+ label IDs to 1-based IDs in the full category numbering.

    facere+ uses a reduced label set with the IDs in `_FACERE_PLUS_EXCLUDED_IDS`
    removed; this re-inserts those gaps so the result can be looked up in the
    full category table.

    Args:
        cat_ids: 1-based label IDs as predicted by the facere+ model.

    Returns:
        list: The corresponding 1-based IDs in the original numbering,
        e.g. ``[1, 2, 3] -> [1, 2, 4]``.

    Raises:
        KeyError: If an ID lies outside the facere+ label range.
    """
    return [_FACERE_PLUS_TO_ORIGINAL_ID[i - 1] + 1 for i in cat_ids]


@functools.lru_cache(maxsize=1)
def process_categories() -> tuple:
    """
    Load and process category information from "categories.json".

    The file is read once per process and the result is cached; callers must not
    mutate the returned dictionaries.

    Returns:
        tuple: A tuple containing two dictionaries:
            - `category_id_to_name`: maps category IDs to their names.
            - `category_id_to_color`: maps category IDs to an RGB color sampled
              (without repetition) from PIL's named colors with a fixed seed, so
              the palette is identical on every run.
    """
    with open("categories.json", encoding="utf-8") as fp:
        categories = json.load(fp)

    # Map category IDs to names
    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # A private seeded RNG gives the same draws as seeding the module-level RNG
    # with 42, without resetting global `random` state for the rest of the process.
    rng = random.Random(42)
    color_names = list(ImageColor.colormap.keys())
    # One distinct color per category.
    sampled_colors = rng.sample(color_names, len(categories))

    # Map category IDs to RGB colors
    category_id_to_color = {
        category["id"]: ImageColor.getrgb(color_name)
        for category, color_name in zip(categories, sampled_colors)
    }
    return category_id_to_name, category_id_to_color


def draw_predictions(
    boxes, labels, scores, masks, img, model_name, score_threshold, proba_threshold
):
    """
    Draw predictions on the input image and save the visualizations as PNG files.

    Only detections whose score exceeds `score_threshold` are kept; for those,
    mask pixels whose probability exceeds `proba_threshold` count as foreground.

    Args:
        - boxes: numpy.ndarray (N, 4) - bounding boxes, in the (xmin, ymin, xmax, ymax)
          layout `draw_bounding_boxes` expects.
        - labels: numpy.ndarray (N,) - 1-based predicted class ID per detection.
        - scores: numpy.ndarray (N,) - confidence score per detection.
        - masks: numpy.ndarray (N, 1, H, W) - pixel-wise mask probabilities per detection.
        - img: torch.Tensor - the input image as a uint8 (3, H, W) tensor (as produced
          by `read_image(..., mode=ImageReadMode.RGB)`).
        - model_name: str - "facere" or "facere+"; facere+ labels are re-mapped with
          `fix_category_id` before lookup.
        - score_threshold: float - confidence threshold for keeping a detection.
        - proba_threshold: float - pixel-wise probability threshold for masks.

    Returns:
        - list[str]: Paths of the saved images, in order: all masks overlaid on the
          image ("img_masks.png"), all boxes drawn on the image ("img_bbox.png"),
          then one figure per kept detection with its binary mask titled
          "<label>: <score>" ("mask-<i>.png").
    """
    imgs_list = []

    # Map label IDs to names and colors
    label_id_to_name, label_id_to_color = process_categories()

    # Apply the same score filter to every per-detection array so boxes, labels,
    # scores and masks stay aligned index-by-index.
    keep = scores > score_threshold
    boxes, labels, scores, masks = boxes[keep], labels[keep], scores[keep], masks[keep]

    labels_id = labels.tolist()
    if model_name == "facere+":
        labels_id = fix_category_id(labels_id)
    # models output is in range: [1,class_id+1], hence re-map to: [0,class_id]
    category_ids = [int(i) - 1 for i in labels_id]
    label_names = [label_id_to_name[i] for i in category_ids]
    label_colors = [label_id_to_color[i] for i in category_ids]

    # Binarize the pixel-wise probabilities.
    binary_masks = (masks > proba_threshold).astype(np.uint8)

    # NOTE(review): outputs use fixed filenames in the working directory, so two
    # requests processed concurrently would overwrite each other's files; this
    # relies on Gradio's queue running one request at a time — confirm.

    # Draw masks to input image and save
    img_masks = draw_segmentation_masks(
        image=img,
        masks=torch.from_numpy(binary_masks.squeeze(1).astype(bool)),
        alpha=0.9,
        colors=label_colors,
    )
    F.to_pil_image(img_masks).save("img_masks.png")
    imgs_list.append("img_masks.png")

    # Draw bboxes to input image and save
    img_bbox = draw_bounding_boxes(img, boxes=torch.from_numpy(boxes), width=4)
    F.to_pil_image(img_bbox).save("img_bbox.png")
    imgs_list.append("img_bbox.png")

    # Save each mask with its label & score; an explicit figure avoids relying on
    # pyplot's implicit "current figure" state.
    for col, (mask, label, score) in enumerate(zip(binary_masks, label_names, scores)):
        fig, ax = plt.subplots()
        ax.imshow(Image.fromarray(mask[0]))
        ax.axis("off")
        ax.set_title(f"{label}: {score:.2f}", fontsize=9)
        fig.savefig(f"mask-{col}.png")
        plt.close(fig)
        imgs_list.append(f"mask-{col}.png")

    return imgs_list


@functools.lru_cache(maxsize=2)
def _load_onnx_session(model_name: str) -> onnxruntime.InferenceSession:
    """
    Fetch the ONNX export for `model_name` from the Hub and open a CPU inference session.

    "facere+" selects `facere_plus.onnx`; any other name selects `facere_base.onnx`.
    Sessions are cached so repeated requests reuse them instead of rebuilding one
    per call. (NOTE(review): if `spaces.GPU` runs `inference` in a separate
    process, this cache may only live for that call — confirm.)
    """
    path_onnx = hf_hub_download(
        repo_id="rizavelioglu/fashionfail",
        filename="facere_plus.onnx" if model_name == "facere+" else "facere_base.onnx",
    )

    # Session options (see https://github.com/microsoft/onnxruntime/issues/14694#issuecomment-1598429295)
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    )

    return onnxruntime.InferenceSession(
        path_onnx,
        providers=["CPUExecutionProvider"],
        sess_options=sess_options,
    )


@spaces.GPU(duration=20)
def inference(image, model_name, mask_threshold, bbox_threshold):
    """
    Run the selected ONNX model on `image` and visualize its predictions.

    Args:
        image: Filepath of the uploaded image (the Gradio input uses type="filepath").
        model_name: "facere" or "facere+" (from the dropdown).
        mask_threshold: Pixel-wise probability threshold for masks.
        bbox_threshold: Confidence threshold for keeping a detection.

    Returns:
        list[str]: Image paths shown in the Gradio gallery (see `draw_predictions`).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Decode as 3-channel RGB so grayscale or RGBA uploads match the 3-channel
    # input the model and `draw_segmentation_masks` require.
    img = read_image(image, mode=ImageReadMode.RGB)
    # Apply original transformation to the image.
    img_transformed = transforms(img)

    ort_session = _load_onnx_session(model_name)

    # compute ONNX Runtime output prediction
    ort_inputs = {
        ort_session.get_inputs()[0].name: img_transformed.unsqueeze(dim=0).numpy()
    }
    boxes, labels, scores, masks = ort_session.run(None, ort_inputs)

    return draw_predictions(
        boxes,
        labels,
        scores,
        masks,
        img,
        model_name,
        score_threshold=bbox_threshold,
        proba_threshold=mask_threshold,
    )


title = "Facere - Demo"
description = r"""This is the demo of the paper FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation.
Upload your image and choose the model for inference from the dropdown menu—either `Facere` or `Facere+`
Check out the project page for more information."""

article = r"""
Example images are sampled from the `Fashionpedia-test` and `FashionFail-test` sets, which the models did not see during training.

**Citation**

If you find our work useful in your research, please consider giving a star ⭐ and a citation:
```
@inproceedings{velioglu2024fashionfail,
  author    = {Velioglu, Riza and Chan, Robin and Hammer, Barbara},
  title     = {FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation},
  booktitle = {IJCNN},
  eprint    = {2404.08582},
  year      = {2024},
}
```
"""

examples = [
    ["examples/0a4f8205a3b58e70eec99fbbb9422d08.jpg", "facere", 0.5, 0.7],
    ["examples/0a72e0f76ab9b75945f5d610508f9336.jpg", "facere", 0.5, 0.7],
    ["examples/0a939e0e67011aecf7195c17ecb9733c.jpg", "facere", 0.5, 0.7],
    ["examples/adi_9086_5.jpg", "facere", 0.5, 0.5],
    ["examples/adi_9086_5.jpg", "facere+", 0.5, 0.5],
    ["examples/adi_9704_1.jpg", "facere", 0.5, 0.5],
    ["examples/adi_9704_1.jpg", "facere+", 0.5, 0.5],
    ["examples/adi_10266_5.jpg", "facere", 0.5, 0.5],
    ["examples/adi_10266_5.jpg", "facere+", 0.5, 0.5],
    ["examples/adi_103_6.jpg", "facere", 0.5, 0.5],
    ["examples/adi_103_6.jpg", "facere+", 0.5, 0.5],
    ["examples/adi_1201_2.jpg", "facere", 0.5, 0.7],
    ["examples/adi_1201_2.jpg", "facere+", 0.5, 0.7],
    ["examples/adi_2149_5.jpg", "facere", 0.5, 0.7],
    ["examples/adi_2149_5.jpg", "facere+", 0.5, 0.7],
    ["examples/adi_5476_3.jpg", "facere", 0.5, 0.7],
    ["examples/adi_5476_3.jpg", "facere+", 0.5, 0.7],
    ["examples/adi_5641_4.jpg", "facere", 0.5, 0.7],
    ["examples/adi_5641_4.jpg", "facere+", 0.5, 0.7],
]

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="input"),
        gr.Dropdown(["facere", "facere+"], value="facere", label="Models"),
        gr.Slider(
            value=0.5,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="Mask threshold",
            info="a threshold for "
            "filtering out "
            "low-probability ("
            "pixel-wise) mask "
            "predictions",
        ),
        gr.Slider(
            value=0.7,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="BBox threshold",
            info="a threshold for "
            "filtering out "
            "low-scoring bbox "
            "predictions",
        ),
    ],
    outputs=gr.Gallery(label="output", preview=True, height=500),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=True,
    examples_per_page=6,
)

if __name__ == "__main__":
    demo.launch()