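"""Gradio demo for the Recognize Anything Model (RAM) and Tag2Text.

Upload an image to get English/Chinese tags (RAM) or tags plus a caption
(Tag2Text), matching the two tabs built below.
"""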
import torch
from ram import get_transform, inference_ram, inference_tag2text
from ram.models import ram, tag2text
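# Pretrained checkpoints are expected next to this script; the weight files
# are distributed with the recognize-anything project.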
ram_checkpoint = "./ram_swin_large_14m.pth"
tag2text_checkpoint = "./tag2text_swin_14m.pth"
image_size = 384
device = "cuda" if torch.cuda.is_available() else "cpu"
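
# gradients are never needed at inference time, so disable autograd for the whole call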
@torch.no_grad()
def inference(raw_image, specified_tags, tagging_model_type, tagging_model, transform):
    """Run one tagging model on a PIL image.

    Returns (tags, tags_chinese) for RAM and (tags, caption) for Tag2Text.
    """
    print(f"Start processing, image size {raw_image.size}")
    image = transform(raw_image).unsqueeze(0).to(device)

    if tagging_model_type == "RAM":
        res = inference_ram(image, tagging_model)
        tags = res[0].strip(' ').replace('  ', ' | ')
        tags_chinese = res[1].strip(' ').replace('  ', ' | ')
        print("Tags: ", tags)
        print("标签: ", tags_chinese)
        return tags, tags_chinese
    else:
        res = inference_tag2text(image, tagging_model, specified_tags)
        tags = res[0].strip(' ').replace('  ', ' | ')
        caption = res[2]
        print(f"Tags: {tags}")
        print(f"Caption: {caption}")
        return tags, caption


def inference_with_ram(img):
    return inference(img, None, "RAM", ram_model, transform)


def inference_with_t2t(img, input_tags):
    return inference(img, input_tags, "Tag2Text", tag2text_model, transform)
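
# Minimal usage sketch (assumes the script has been run so that ram_model,
# tag2text_model, and transform exist, and that the repo's demo images are present):
#
#     from PIL import Image
#     tags, tags_chinese = inference_with_ram(Image.open("images/demo1.jpg"))
#     tags, caption = inference_with_t2t(Image.open("images/demo4.jpg"), "train")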

if __name__ == "__main__":
    import gradio as gr

    # get transform and load models
    transform = get_transform(image_size=image_size)
    ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
    tag2text_model = tag2text(
        pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)
    # build GUI
    def build_gui():
        description = """
        <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
        <br>
        <li>
        <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
        </li>
        <li>
        <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>! (Optional: Specify tags to get the corresponding caption.)
        </li>
        """  # noqa

        article = """
        """  # noqa
        with gr.Blocks(title="Recognize Anything Model") as demo:
            ###############
            # components
            ###############
            gr.HTML(description)

            with gr.Tab(label="Recognize Anything Model"):
                with gr.Row():
                    with gr.Column():
                        ram_in_img = gr.Image(type="pil")
                        with gr.Row():
                            ram_btn_run = gr.Button(value="Run")
                            try:
                                ram_btn_clear = gr.ClearButton()
                            except AttributeError:  # older gradio has no ClearButton; degrade gracefully
                                ram_btn_clear = None
                    with gr.Column():
                        ram_out_tag = gr.Textbox(label="Tags")
                        ram_out_biaoqian = gr.Textbox(label="标签")
                        gr.Examples(
                            examples=[
                                ["images/demo1.jpg"],
                                ["images/demo2.jpg"],
                                ["images/demo4.jpg"],
                            ],
                            fn=inference_with_ram,
                            inputs=[ram_in_img],
                            outputs=[ram_out_tag, ram_out_biaoqian],
                            cache_examples=True
                        )
with gr.Tab(label="Tag2Text Model"):
with gr.Row():
with gr.Column():
t2t_in_img = gr.Image(type="pil")
t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
with gr.Row():
t2t_btn_run = gr.Button(value="Run")
try:
t2t_btn_clear = gr.ClearButton()
except AttributeError: # old gradio does not have ClearButton, not big problem
t2t_btn_clear = None
with gr.Column():
t2t_out_tag = gr.Textbox(label="Tags")
t2t_out_cap = gr.Textbox(label="Caption")
gr.Examples(
examples=[
["images/demo4.jpg", ""],
["images/demo4.jpg", "power line"],
["images/demo4.jpg", "track, train"],
],
fn=inference_with_t2t,
inputs=[t2t_in_img, t2t_in_tag],
outputs=[t2t_out_tag, t2t_out_cap],
cache_examples=True
)
            gr.HTML(article)

            ###############
            # events
            ###############
            # run inference
            ram_btn_run.click(
                fn=inference_with_ram,
                inputs=[ram_in_img],
                outputs=[ram_out_tag, ram_out_biaoqian]
            )
            t2t_btn_run.click(
                fn=inference_with_t2t,
                inputs=[t2t_in_img, t2t_in_tag],
                outputs=[t2t_out_tag, t2t_out_cap]
            )

            # clear
            if ram_btn_clear is not None:
                ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian])
            if t2t_btn_clear is not None:
                t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap])

        return demo
    build_gui().launch(enable_queue=True)
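    # enable_queue is a Gradio 3.x launch flag; newer Gradio versions queue
    # requests via build_gui().queue().launch() instead.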