# Spaces status (page-capture residue, not part of the program): Runtime error
import gradio as gr | |
from off_topic import OffTopicDetector, Translator | |
# Module-level singletons shared by every request handled by ``validate``.
# The string arguments look like pretrained-model identifiers — TODO confirm
# against the ``off_topic`` module.
translator = Translator("facebook/nllb-200-distilled-600M")

# NOTE(review): the meaning of image_size="V" is defined by OffTopicDetector
# and is not visible from this file.
detector = OffTopicDetector(
    "openai/clip-vit-base-patch32",
    image_size="V",
    translator=translator,
)
def validate(item_id: str, use_title: bool, threshold: float):
    """Split an item's pictures into valid and invalid lists.

    Queries the module-level ``detector`` for the item's images and their
    per-image "valid" probabilities, then sorts each image by comparing its
    probability with ``threshold``.

    Args:
        item_id: Identifier of the item whose pictures are classified.
        use_title: Forwarded to ``detector.predict_probas_item`` as the
            ``use_title`` keyword.
        threshold: Cut-off applied to each image's valid probability;
            values greater than or equal to it count as valid.

    Returns:
        A 3-tuple ``(heading, valid_images, invalid_images)`` where
        ``heading`` is the Markdown string ``"## Domain: <domain>"``.
    """
    # The detector returns five values; only the images, the domain and the
    # "valid" scores are used here.
    pictures, domain, _, scores, _ = detector.predict_probas_item(item_id, use_title=use_title)

    accepted = []
    rejected = []
    for position, picture in enumerate(pictures):
        score = scores[position].squeeze()
        if score >= threshold:
            accepted.append(picture)
        elif score < threshold:
            # Explicit second comparison: a score that is neither >= nor <
            # the threshold ends up in neither list.
            rejected.append(picture)

    return f"## Domain: {domain}", accepted, rejected
# Build the demo page; components are created in the same order as before.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost — confirm the Row/Column layout against the original.
with gr.Blocks() as demo:
    gr.Markdown("""
# Off topic image detector
### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
Input an item ID or select one of the preloaded examples below.""")

    # Query controls: the item ID next to a column with the two options.
    with gr.Row():
        item_id = gr.Textbox(label="Item ID")
        with gr.Column():
            use_title = gr.Checkbox(label="Use translated item title", value=True)
            threshold = gr.Number(label="Threshold", value=0.25, precision=2)
    submit = gr.Button("Submit")

    # Result area: domain heading followed by the two galleries.
    gr.HTML("<hr>")
    domain = gr.Markdown()
    # NOTE(review): ``.style(...)`` depends on the installed Gradio version —
    # confirm it is still supported by the pinned release.
    valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
    gr.HTML("<hr>")
    invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")

    # The button and the preloaded examples share the same wiring.
    demo_inputs = [item_id, use_title, threshold]
    demo_outputs = [domain, valid, invalid]
    submit.click(fn=validate, inputs=demo_inputs, outputs=demo_outputs)

    gr.HTML("<hr>")
    example_ids = (
        "MLC572974424",
        "MLU449951849",
        "MLA1293465558",
        "MLB3184663685",
        "MLC1392230619",
        "MCO546152796",
    )
    # Every example uses the title and the default 0.25 threshold.
    example_rows = [[example_id, True, 0.25] for example_id in example_ids]
    gr.Examples(
        examples=example_rows,
        inputs=demo_inputs,
        outputs=demo_outputs,
        fn=validate,
        cache_examples=True,
    )

demo.launch()