fashionfail / app.py
rizavelioglu's picture
Update app.py
35ed39f verified
raw
history blame
10.4 kB
import json
import random
import spaces
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import torch
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image, ImageColor
from torchvision.io import read_image
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
# Preprocessing bundled with torchvision's default Mask R-CNN (ResNet-50 FPN) weights.
# The ONNX models themselves are loaded separately in `inference`; only this transform
# pipeline is applied to each input image before it is fed to them.
weights, transforms = (
    MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
    MaskRCNN_ResNet50_FPN_Weights.DEFAULT.transforms(),
)
def fix_category_id(cat_ids: list) -> list:
    """
    Map category ids predicted by the "facere+" model back to the original category ids.

    The "facere+" ids number the categories of ``range(27)`` consecutively, skipping the
    original ids listed in ``excluded_indices``. Both the input and the output ids are
    1-based (id ``k`` refers to zero-based category ``k - 1``).

    Args:
    - cat_ids: list - 1-based category ids as predicted by the "facere+" model.

    Returns:
    - A list of 1-based category ids in the original id space, in the same order as the input.

    Raises:
    - KeyError: if an id is outside ``1..len(remaining_categories)``.
    """
    # Original (zero-based) category ids that the "facere+" id space skips.
    excluded_indices = {2, 12, 16, 19, 20}
    # `sorted` makes the new-id -> original-id ordering explicit instead of relying on
    # set iteration order, which is an implementation detail.
    remaining_categories = sorted(set(range(27)) - excluded_indices)
    # Zero-based new id -> zero-based original id.
    new_id_to_org_id = dict(enumerate(remaining_categories))
    # Convert to zero-based for the lookup, then back to one-based for the result.
    return [new_id_to_org_id[i - 1] + 1 for i in cat_ids]
def process_categories() -> tuple:
    """
    Load and process category information from a JSON file.

    Reads ``categories.json`` from the current working directory and builds two lookups
    keyed by each category's ``"id"``. Colors are sampled from PIL's named-color map with
    a fixed seed, so the id -> color assignment is the same on every call.

    Returns:
        tuple: A tuple containing two dictionaries:
            - `category_id_to_name`: a dictionary mapping category IDs to their names.
            - `category_id_to_color`: a dictionary mapping category IDs to an RGB color.
              Only the first 46 categories in the file receive a color (46 colors are
              sampled and paired with categories via ``zip``).
    """
    # Load raw categories from JSON file (a list of dicts with at least "id" and "name").
    with open("categories.json", encoding="utf-8") as fp:
        categories = json.load(fp)

    # Map category IDs to names
    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # Use a private RNG seeded with 42: it produces the same sample as
    # `random.seed(42); random.sample(...)`, but does not reset the process-global
    # `random` state as a side effect.
    rng = random.Random(42)
    # All color names known to PIL, in the colormap's own order.
    color_names = list(ImageColor.colormap.keys())
    # Sample 46 unique color names and convert them to RGB tuples.
    sampled_colors = rng.sample(color_names, 46)
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]

    # Map category IDs to colors (pairs categories with colors in file order).
    category_id_to_color = {
        category["id"]: color for category, color in zip(categories, rgb_colors)
    }
    return category_id_to_name, category_id_to_color
def draw_predictions(
    boxes, labels, scores, masks, img, model_name, score_threshold, proba_threshold
):
    """
    Draw predictions on the input image based on the provided boxes, labels, scores, and masks. Only predictions
    with scores above the `score_threshold` will be included, and mask pixels with probabilities exceeding the
    `proba_threshold` will be displayed.

    Args:
    - boxes: numpy.ndarray - an array of bounding box coordinates, one row per prediction.
    - labels: numpy.ndarray - an array of integers representing the predicted (1-based) class for each box.
    - scores: numpy.ndarray - an array of confidence scores for each bounding box.
    - masks: numpy.ndarray - per-prediction mask probabilities; assumed shape (N, 1, H, W) since axis 1 is
      squeezed below -- TODO confirm against the ONNX export.
    - img: torch.Tensor - the input image tensor (as produced by `read_image` in `inference`).
    - model_name: str - name of the model given by the dropdown menu, either "facere" or "facere+".
    - score_threshold: float - a confidence score threshold for filtering out low-scoring bbox predictions.
    - proba_threshold: float - a threshold for filtering out low-probability (pixel-wise) mask predictions.

    Returns:
    - A list of strings, each representing the path to an image file: the input image with all masks drawn,
      the input image with all boxes drawn, and then one plot per kept mask titled with its label and score.
    """
    # NOTE(review): outputs are written under fixed names in the working directory, so
    # concurrent requests can overwrite each other's files.
    imgs_list = []

    # Map label IDs to names and colors
    label_id_to_name, label_id_to_color = process_categories()

    # One boolean selector applied to every per-prediction array so they stay aligned.
    # `scores` must be filtered too: it is zipped with the filtered masks/labels below.
    keep = scores > score_threshold
    scores = scores[keep]
    boxes = boxes[keep]
    labels_id = labels[keep].tolist()
    if model_name == "facere+":
        labels_id = fix_category_id(labels_id)
    # models output is in range: [1,class_id+1], hence re-map to: [0,class_id]
    labels = [label_id_to_name[int(i) - 1] for i in labels_id]
    # Binarize each kept mask at the pixel-probability threshold.
    masks = (masks[keep] > proba_threshold).astype(np.uint8)

    # Draw masks to input image and save
    img_masks = draw_segmentation_masks(
        image=img,
        masks=torch.from_numpy(masks.squeeze(1).astype(bool)),
        alpha=0.9,
        colors=[label_id_to_color[int(i) - 1] for i in labels_id],
    )
    F.to_pil_image(img_masks).save("img_masks.png")
    imgs_list.append("img_masks.png")

    # Draw bboxes to input image and save
    img_bbox = draw_bounding_boxes(img, boxes=torch.from_numpy(boxes), width=4)
    F.to_pil_image(img_bbox).save("img_bbox.png")
    imgs_list.append("img_bbox.png")

    # Save each kept mask as its own figure, titled with its label and score.
    # Explicit figure objects avoid depending on pyplot's implicit "current figure".
    for col, (mask, label, score) in enumerate(zip(masks, labels, scores)):
        fig, ax = plt.subplots()
        ax.imshow(Image.fromarray(mask.squeeze()))
        ax.axis("off")
        ax.set_title(f"{label}: {score:.2f}", fontsize=9)
        fig.savefig(f"mask-{col}.png")
        plt.close(fig)
        imgs_list.append(f"mask-{col}.png")
    return imgs_list
@spaces.GPU(duration=20)
def inference(image, model_name, mask_threshold, bbox_threshold):
    """
    Load the ONNX model and run inference with the provided input `image`. Visualize the predictions and save them in a
    figure, which will be shown in the Gradio app.

    Args:
    - image: str or None - file path of the uploaded image (the Gradio input uses ``type="filepath"``);
      ``None`` when nothing was uploaded.
    - model_name: str - "facere" or "facere+"; selects which ONNX file is downloaded.
    - mask_threshold: float - pixel-wise probability threshold, forwarded as `proba_threshold`.
    - bbox_threshold: float - detection score threshold, forwarded as `score_threshold`.

    Returns:
    - The list of image file paths produced by `draw_predictions`.

    Raises:
    - gr.Error: if no image was provided, so the user sees a readable message in the UI.
    """
    from torchvision.io import ImageReadMode

    if image is None:
        raise gr.Error("Please upload an image.")

    # Load image as 3-channel RGB: RGBA (e.g. PNG with alpha) or grayscale uploads would
    # otherwise yield tensors that `draw_segmentation_masks` rejects.
    img = read_image(image, mode=ImageReadMode.RGB)
    # Apply original transformation to the image.
    img_transformed = transforms(img)

    # Download model (hf_hub_download reuses its local cache on repeated calls).
    path_onnx = hf_hub_download(
        repo_id="rizavelioglu/fashionfail",
        filename="facere_plus.onnx" if model_name == "facere+" else "facere_base.onnx",
    )

    # Session options (see https://github.com/microsoft/onnxruntime/issues/14694#issuecomment-1598429295)
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    # Create an inference session.
    ort_session = onnxruntime.InferenceSession(
        path_onnx,
        providers=["CPUExecutionProvider"],
        sess_options=sess_options,
    )

    # compute ONNX Runtime output prediction on a batch of one image
    ort_inputs = {
        ort_session.get_inputs()[0].name: img_transformed.unsqueeze(dim=0).numpy()
    }
    boxes, labels, scores, masks = ort_session.run(None, ort_inputs)

    return draw_predictions(
        boxes,
        labels,
        scores,
        masks,
        img,
        model_name,
        score_threshold=bbox_threshold,
        proba_threshold=mask_threshold,
    )
# Text shown in the Gradio interface header (title/description) and footer (article).
title = "Facere - Demo"
description = r"""This is the demo of the paper <a href="https://arxiv.org/abs/2404.08582">FashionFail: Addressing
Failure Cases in Fashion Object Detection and Segmentation</a>. <br>Upload your image and choose the model for inference
from the dropdown menu—either `Facere` or `Facere+` <br> Check out the <a
href="https://rizavelioglu.github.io/fashionfail/">project page</a> for more information."""
# The citation is an @inproceedings entry, whose venue field in BibTeX is `booktitle`
# (`journal` is not a field of that entry type and is ignored by standard styles).
article = r"""
Example images are sampled from the `Fashionpedia-test` and `FashionFail-test` set, which the models did not see during training.
<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
a citation:
```
@inproceedings{velioglu2024fashionfail,
author = {Velioglu, Riza and Chan, Robin and Hammer, Barbara},
title = {FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation},
booktitle = {IJCNN},
eprint = {2404.08582},
year = {2024},
}
```
"""
# Clickable examples for the demo. Each row is [image path, model name, mask threshold,
# bbox threshold], in the same order as the Interface inputs.
_single_model_examples = [
    ["examples/0a4f8205a3b58e70eec99fbbb9422d08.jpg", "facere", 0.5, 0.7],
    ["examples/0a72e0f76ab9b75945f5d610508f9336.jpg", "facere", 0.5, 0.7],
    ["examples/0a939e0e67011aecf7195c17ecb9733c.jpg", "facere", 0.5, 0.7],
]
# Images shown once per model ("facere" first, then "facere+"); the value is the bbox
# threshold used for both rows. The mask threshold is always 0.5.
_both_models_bbox_threshold = {
    "adi_9086_5": 0.5,
    "adi_9704_1": 0.5,
    "adi_10266_5": 0.5,
    "adi_103_6": 0.5,
    "adi_1201_2": 0.7,
    "adi_2149_5": 0.7,
    "adi_5476_3": 0.7,
    "adi_5641_4": 0.7,
}
examples = _single_model_examples + [
    [f"examples/{stem}.jpg", model, 0.5, bbox_threshold]
    for stem, bbox_threshold in _both_models_bbox_threshold.items()
    for model in ("facere", "facere+")
]
# Gradio UI: an uploaded image plus model/threshold controls go in (in the argument order
# of `inference`), a gallery of rendered prediction images comes out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(label="input", type="filepath"),
        gr.Dropdown(["facere", "facere+"], label="Models", value="facere"),
        gr.Slider(
            minimum=0.0,
            maximum=0.9,
            value=0.5,
            step=0.05,
            label="Mask threshold",
            info="a threshold for filtering out low-probability (pixel-wise) mask predictions",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=0.9,
            value=0.7,
            step=0.05,
            label="BBox threshold",
            info="a threshold for filtering out low-scoring bbox predictions",
        ),
    ],
    outputs=gr.Gallery(label="output", height=500, preview=True),
    examples=examples,
    examples_per_page=6,
    cache_examples=True,
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    demo.launch()