import gradio as gr
from transformers import AutoModel, AutoProcessor
import torch
import requests
from PIL import Image
from io import BytesIO

# Candidate fashion categories for zero-shot classification
fashion_items = ['top', 'trousers', 'hat', 'jumper']

# Load the Marqo FashionSigLIP model and its processor
model_name = 'Marqo/marqo-fashionSigLIP'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
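
# Optional: the model is only used for inference, so eval mode can be set explicitly.
# model.eval()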

# Pre-compute the label (text) embeddings once, since the candidate labels never change
with torch.no_grad():
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,
        padding=True
    )['input_ids']
    text_features = model.get_text_features(processed_texts)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
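
# Optional sanity check: one embedding per label; the embedding width depends on the model.
# print(text_features.shape)  # expected: (len(fashion_items), embedding_dim)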


def predict_from_url(url):
    # Guard against an empty input box
    if not url:
        return {"Error": "Please input a URL"}

    # Fetch and decode the image from the URL
    try:
        image = Image.open(BytesIO(requests.get(url).content))
    except Exception as e:
        return {"Error": f"Failed to load image: {str(e)}"}

    processed_image = processor(images=image, return_tensors="pt")['pixel_values']

    # Encode the image and score it against the pre-computed label embeddings
    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Map each label to its probability for the gr.Label component
    return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}
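
# Optional quick check before launching the UI; the URL below is a placeholder,
# substitute any publicly reachable image of a garment:
# print(predict_from_url("https://example.com/sample-jumper.jpg"))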


# Build the Gradio UI: a URL textbox in, classification probabilities out
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=gr.Label(label="Classification Results"),
    title="Fashion Item Classifier",
    allow_flagging="never"
)

demo.launch()
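
# To expose the demo on a temporary public URL (e.g. when running in Colab),
# Gradio can tunnel the app instead:
# demo.launch(share=True)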