import gradio as gr

from zhclip import ZhCLIPProcessor, ZhCLIPModel

# Load the ZH-CLIP model and its paired processor from the Hugging Face Hub.
version = 'thu-ml/zh-clip-vit-roberta-large-patch14'
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)
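
# ZhCLIPModel / ZhCLIPProcessor are not part of the standard transformers API;
# they ship with the zh-clip repo linked below, so the zhclip module must be
# importable (e.g. by running this script from a checkout of that repo).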

def inference(image, texts):
    # The "list" input arrives as rows of single-element lists; flatten to plain strings.
    texts = [x[0] for x in texts]
    # Preprocess the image and tokenize all candidate texts as one padded batch.
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    image_features = outputs.image_features
    text_features = outputs.text_features
    # Softmax over the image-text similarity scores gives one probability per candidate text.
    text_probs = (image_features @ text_features.T).softmax(dim=-1)[0].detach().cpu().numpy()
    # Key the result by the candidate text itself so the Label component shows readable names.
    return {texts[i]: float(text_probs[i]) for i in range(len(text_probs))}
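
# A minimal sketch of calling inference() directly, outside Gradio. It assumes
# Pillow is installed and ./images/dog.jpeg exists (the bundled example image);
# the nested lists mirror the rows the Gradio "list" input delivers:
#
#   from PIL import Image
#   probs = inference(Image.open('./images/dog.jpeg'), [['一只狗'], ['一只猫']])
#   # probs maps each candidate text ("a dog", "a cat") to its probability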

title = "ZH-CLIP zero-shot classification"
description = "Zero-shot classification with the Chinese CLIP model (ZH-CLIP)"
article = "<p style='text-align: center'><a href='https://www.github.com/thu-ml/zh-clip' target='_blank'>github: zh-clip</a> <a href='https://huggingface.co/thu-ml/zh-clip-vit-roberta-large-patch14' target='_blank'>huggingface model: thu-ml/zh-clip-vit-roberta-large-patch14</a></p>"
# Example inputs: '一只狗' ("a dog") and '一只猫' ("a cat").
examples = [['./images/dog.jpeg', [['一只狗'], ['一只猫']]]]
interpretation = 'default'
enable_queue = True

iface = gr.Interface(fn=inference, inputs=["image", "list"], outputs="label",
                     title=title, description=description, article=article,
                     examples=examples, interpretation=interpretation,
                     enable_queue=enable_queue)
iface.launch(server_name='0.0.0.0')
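
# server_name='0.0.0.0' listens on all interfaces, which is what containerized
# hosts such as Hugging Face Spaces expect; Gradio serves on port 7860 by
# default unless server_port= or the GRADIO_SERVER_PORT env var overrides it.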