import numpy as np
import random
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.tag2text import tag2text_caption, ram
import gradio as gr
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length of the square input passed to both models and to the resize below.
image_size = 384

# Preprocessing pipeline: square resize -> tensor -> per-channel normalisation
# (these mean/std values are the standard ImageNet statistics).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ]
)

# --- Tag2Text model (vit='swin_b'); weights file given via `pretrained` ---
pretrained = 'tag2text_swin_14m.pth'
model_tag2text = tag2text_caption(
    pretrained=pretrained,
    image_size=image_size,
    vit='swin_b',
)
model_tag2text.eval()  # switch to eval mode before serving requests
model_tag2text = model_tag2text.to(device)

# --- Recognize Anything Model (vit='swin_l'); note `pretrained` is rebound ---
pretrained = 'ram_swin_large_14m.pth'
model_ram = ram(
    pretrained=pretrained,
    image_size=image_size,
    vit='swin_l',
)
model_ram.eval()  # switch to eval mode before serving requests
model_ram = model_ram.to(device)
def inference(raw_image, model_n, input_tag):
    """Run tagging (and, for Tag2Text, captioning) on a single image.

    Args:
        raw_image: A PIL image, as produced by ``gr.Image(type="pil")``.
        model_n: ``'Recognize Anything Model'`` selects RAM; any other value
            selects Tag2Text.
        input_tag: Optional comma-separated tags that guide the Tag2Text
            caption. ``None``, an empty/whitespace string, or ``'none'``
            (any case) means "no user-specified tags". Ignored by RAM.

    Returns:
        A 3-tuple:
        - RAM: ``(english_tags, chinese_tags, 'none')``
        - Tag2Text: ``(model_predicted_tags, 'none', caption)``

    Raises:
        ValueError: If no image was supplied (e.g. Run clicked before upload).
    """
    if raw_image is None:
        raise ValueError("No image provided; please upload an image first.")

    # convert("RGB") makes grayscale/RGBA/palette inputs safe for the
    # 3-channel Normalize inside `transform`; it is a plain copy for RGB input.
    raw_image = raw_image.convert("RGB").resize((image_size, image_size))
    image = transform(raw_image).unsqueeze(0).to(device)

    if model_n == 'Recognize Anything Model':
        with torch.no_grad():
            tags, tags_chinese = model_ram.generate_tag(image)
        return tags[0], tags_chinese[0], 'none'

    model = model_tag2text
    # NOTE(review): presumably the tag score cutoff used by generate() — confirm.
    model.threshold = 0.68

    # Missing or placeholder input means "let the model choose the tags".
    if input_tag is None or input_tag.strip().lower() in ('', 'none'):
        input_tag_list = None
    else:
        # User tags are passed to the model separated by ' | '.
        input_tag_list = [input_tag.replace(',', ' | ')]

    with torch.no_grad():
        caption, tag_predict = model.generate(
            image, tag_input=input_tag_list, max_length=50,
            return_tag_predict=True)
        if input_tag_list is None:
            predicted_tags = tag_predict
        else:
            # The caption above was guided by the user's tags; run a second,
            # tag-free pass so the reported tags are the model's own.
            _, predicted_tags = model.generate(
                image, tag_input=None, max_length=50,
                return_tag_predict=True)

    return predicted_tags[0], 'none', caption[0]
_DESCRIPTION = """
Recognize Anything Model
Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo!
Recognize Anything Model: Upload your image to get the English and Chinese outputs of the image tags!
Tag2Text Model: Upload your image to get the tags and caption of the image.
Optional: You can also input specified tags to get the corresponding caption.
"""  # noqa

_ARTICLE = """
RAM and Tag2Text is training on open-source datasets, and we are persisting in refining and iterating upon it.
Recognize Anything: A Strong Image Tagging Model
|
Tag2Text: Guiding Language-Image Model via Image Tagging
|
Github Repo
"""  # noqa


def build_gui():
    """Assemble and return the Gradio Blocks UI for the RAM / Tag2Text demo.

    The UI has one tab per model. Each tab holds an image input, Run and
    Clear buttons, output textboxes and a set of cached examples. Either
    Clear button resets every input and output on both tabs.
    """

    def inference_with_ram(img):
        # Keep (English tags, Chinese tags); the third slot is unused for RAM.
        tags_en, tags_zh, _ = inference(img, "Recognize Anything Model", None)
        return tags_en, tags_zh

    def inference_with_t2t(img, input_tags):
        # Keep (tags, caption); the middle slot is unused for Tag2Text.
        tags, _, caption = inference(img, "Tag2Text Model", input_tags)
        return tags, caption

    with gr.Blocks(title="Recognize Anything Model") as demo:
        # ---------------- layout ----------------
        gr.HTML(_DESCRIPTION)

        with gr.Tab(label="Recognize Anything Model"):
            with gr.Row():
                with gr.Column():
                    ram_image = gr.Image(type="pil")
                    with gr.Row():
                        ram_run = gr.Button(value="Run")
                        ram_clear = gr.Button(value="Clear")
                with gr.Column():
                    ram_tags_en = gr.Textbox(label="Tags")
                    # Label is Chinese for "Tags".
                    ram_tags_zh = gr.Textbox(label="标签")
            gr.Examples(
                examples=[
                    ["images/demo1.jpg"],
                    ["images/demo2.jpg"],
                    ["images/demo4.jpg"],
                ],
                fn=inference_with_ram,
                inputs=[ram_image],
                outputs=[ram_tags_en, ram_tags_zh],
                cache_examples=True,
            )

        with gr.Tab(label="Tag2Text Model"):
            with gr.Row():
                with gr.Column():
                    t2t_image = gr.Image(type="pil")
                    t2t_tags_in = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
                    with gr.Row():
                        t2t_run = gr.Button(value="Run")
                        t2t_clear = gr.Button(value="Clear")
                with gr.Column():
                    t2t_tags_out = gr.Textbox(label="Tags")
                    t2t_caption = gr.Textbox(label="Caption")
            gr.Examples(
                examples=[
                    ["images/demo4.jpg", ""],
                    ["images/demo4.jpg", "power line"],
                    ["images/demo4.jpg", "track, train"],
                ],
                fn=inference_with_t2t,
                inputs=[t2t_image, t2t_tags_in],
                outputs=[t2t_tags_out, t2t_caption],
                cache_examples=True,
            )

        gr.HTML(_ARTICLE)

        # ---------------- events ----------------
        ram_run.click(
            fn=inference_with_ram,
            inputs=[ram_image],
            outputs=[ram_tags_en, ram_tags_zh],
        )
        t2t_run.click(
            fn=inference_with_t2t,
            inputs=[t2t_image, t2t_tags_in],
            outputs=[t2t_tags_out, t2t_caption],
        )

        # Syncing the image between the two tabs (and clearing stale outputs
        # when it changes) is deliberately not wired up: it was slow due to
        # network latency when deployed on Hugging Face.

        image_fields = [ram_image, t2t_image]
        text_fields = [ram_tags_en, ram_tags_zh, t2t_tags_in, t2t_tags_out, t2t_caption]

        def clear_all():
            # Images reset to None, every textbox to the empty string.
            return ([gr.update(value=None)] * len(image_fields)
                    + [gr.update(value="")] * len(text_fields))

        # Both Clear buttons reset the whole UI, not just their own tab.
        for clear_button in (ram_clear, t2t_clear):
            clear_button.click(fn=clear_all, inputs=[], outputs=image_fields + text_fields)

    return demo
if __name__ == "__main__":
    demo = build_gui()
    # Enable request queuing through Blocks.queue(): the
    # launch(enable_queue=...) keyword is deprecated in later Gradio 3.x
    # releases and no longer accepted by Gradio 4.x.
    demo.queue()
    demo.launch()