File size: 5,656 Bytes
b1f218b
 
4b6c116
24a2388
 
 
b1f218b
 
24a2388
 
 
b1f218b
24a2388
 
b1f218b
24a2388
 
 
b1f218b
 
24a2388
b1f218b
 
24a2388
 
b1f218b
24a2388
 
 
b1f218b
24a2388
 
b1f218b
 
24a2388
 
b1f218b
 
8962d34
 
24a2388
 
 
b1f218b
 
 
4b6c116
b1f218b
24a2388
 
 
 
 
b1f218b
24a2388
 
 
 
 
b1f218b
24a2388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f218b
 
 
 
24a2388
 
 
 
 
b1f218b
 
 
24a2388
 
b1f218b
 
24a2388
 
 
 
 
 
 
 
 
 
b1f218b
 
 
 
24a2388
 
 
 
 
b1f218b
 
 
24a2388
 
b1f218b
 
24a2388
 
 
 
 
 
 
 
 
 
02634b8
b1f218b
 
02634b8
24a2388
02634b8
b1f218b
 
02634b8
 
24a2388
b1f218b
 
 
 
02634b8
24a2388
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
from ram import get_transform, inference_ram, inference_tag2text
from ram.models import ram, tag2text

# Pretrained weights, resolved relative to the current working directory.
ram_checkpoint = "./ram_swin_large_14m.pth"
tag2text_checkpoint = "./tag2text_swin_14m.pth"

# Input resolution used both for the preprocessing transform and for building
# the two models in the __main__ section.
image_size = 384

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def _clean_tags(text):
    """Normalize whitespace in a model-produced tag string.

    Strips leading/trailing whitespace and collapses every internal run of
    whitespace to a single space. A single ``replace('  ', ' ')`` pass would
    leave runs of three or more spaces only partly collapsed.
    """
    return " ".join(text.split())


@torch.no_grad()
def inference(raw_image, specified_tags, tagging_model_type, tagging_model, transform):
    """Tag a single image with either the RAM or the Tag2Text model.

    Args:
        raw_image: Input image. It must expose ``.size`` and be accepted by
            *transform*; the Gradio UI supplies a PIL image
            (``gr.Image(type="pil")``). It may be ``None`` if Run is clicked
            with no image uploaded (presumably how Gradio reports an empty
            Image input).
        specified_tags: Optional user tags forwarded to ``inference_tag2text``.
            Ignored on the RAM path.
        tagging_model_type: ``"RAM"`` selects the RAM path; any other value
            selects the Tag2Text path.
        tagging_model: The loaded model matching *tagging_model_type*.
        transform: Preprocessing callable producing a tensor for one image.

    Returns:
        ``(tags, tags_chinese)`` on the RAM path, ``(tags, caption)`` on the
        Tag2Text path.

    Raises:
        ValueError: If *raw_image* is ``None``.
    """
    # Fail with a readable message instead of an AttributeError on `.size`.
    if raw_image is None:
        raise ValueError("No input image provided; please upload an image first.")

    print(f"Start processing, image size {raw_image.size}")

    # Add a leading (batch) dimension and move to the device the models use.
    image = transform(raw_image).unsqueeze(0).to(device)

    if tagging_model_type == "RAM":
        res = inference_ram(image, tagging_model)
        tags = _clean_tags(res[0])
        tags_chinese = _clean_tags(res[1])
        print(f"Tags: {tags}")
        print(f"标签: {tags_chinese}")
        return tags, tags_chinese

    res = inference_tag2text(image, tagging_model, specified_tags)
    tags = _clean_tags(res[0])
    # res[1] is not used by this demo; res[2] is the caption returned to the UI.
    caption = res[2]
    print(f"Tags: {tags}")
    print(f"Caption: {caption}")
    return tags, caption


def inference_with_ram(img):
    """Gradio callback for the RAM tab: return ``(tags, tags_chinese)`` for *img*.

    Reads the module-level ``ram_model`` and ``transform``, which are only
    created when this file is executed as a script.
    """
    return inference(
        raw_image=img,
        specified_tags=None,
        tagging_model_type="RAM",
        tagging_model=ram_model,
        transform=transform,
    )


def inference_with_t2t(img, input_tags):
    """Gradio callback for the Tag2Text tab: return ``(tags, caption)`` for *img*.

    *input_tags* is the user's optional tag text, passed straight through as
    ``specified_tags``. Reads the module-level ``tag2text_model`` and
    ``transform``, which are only created when this file is executed as a
    script.
    """
    return inference(
        raw_image=img,
        specified_tags=input_tags,
        tagging_model_type="Tag2Text",
        tagging_model=tag2text_model,
        transform=transform,
    )


if __name__ == "__main__":
    import gradio as gr

    # Build the shared preprocessing transform and load both models once at
    # startup; the Gradio callbacks defined above read these module-level names.
    transform = get_transform(image_size=image_size)
    ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
    tag2text_model = tag2text(
        pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)

    def _make_clear_button():
        """Create a ``gr.ClearButton`` in the current layout context, or return ``None``."""
        try:
            return gr.ClearButton()
        except AttributeError:  # old gradio does not have ClearButton, not big problem
            return None

    # build GUI
    def build_gui():
        """Assemble and return the two-tab Gradio demo (RAM and Tag2Text)."""

        description = """
            <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
            <br>
            <li>
                <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
            </li>
            <li>
                <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>! (Optional: Specify tags to get the corresponding caption.)
            </li>
        """  # noqa

        article = """
        """  # noqa

        with gr.Blocks(title="Recognize Anything Model") as demo:
            ###############
            # components
            ###############
            gr.HTML(description)

            with gr.Tab(label="Recognize Anything Model"):
                with gr.Row():
                    with gr.Column():
                        ram_in_img = gr.Image(type="pil")
                        with gr.Row():
                            ram_btn_run = gr.Button(value="Run")
                            ram_btn_clear = _make_clear_button()
                    with gr.Column():
                        ram_out_tag = gr.Textbox(label="Tags")
                        ram_out_biaoqian = gr.Textbox(label="标签")
                # Examples are cached, so inference runs on them at startup.
                gr.Examples(
                    examples=[
                        ["images/demo1.jpg"],
                        ["images/demo2.jpg"],
                        ["images/demo4.jpg"],
                    ],
                    fn=inference_with_ram,
                    inputs=[ram_in_img],
                    outputs=[ram_out_tag, ram_out_biaoqian],
                    cache_examples=True
                )

            with gr.Tab(label="Tag2Text Model"):
                with gr.Row():
                    with gr.Column():
                        t2t_in_img = gr.Image(type="pil")
                        t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
                        with gr.Row():
                            t2t_btn_run = gr.Button(value="Run")
                            t2t_btn_clear = _make_clear_button()
                    with gr.Column():
                        t2t_out_tag = gr.Textbox(label="Tags")
                        t2t_out_cap = gr.Textbox(label="Caption")
                gr.Examples(
                    examples=[
                        ["images/demo4.jpg", ""],
                        ["images/demo4.jpg", "power line"],
                        ["images/demo4.jpg", "track, train"],
                    ],
                    fn=inference_with_t2t,
                    inputs=[t2t_in_img, t2t_in_tag],
                    outputs=[t2t_out_tag, t2t_out_cap],
                    cache_examples=True
                )

            gr.HTML(article)

            ###############
            # events
            ###############
            # run inference
            ram_btn_run.click(
                fn=inference_with_ram,
                inputs=[ram_in_img],
                outputs=[ram_out_tag, ram_out_biaoqian]
            )
            t2t_btn_run.click(
                fn=inference_with_t2t,
                inputs=[t2t_in_img, t2t_in_tag],
                outputs=[t2t_out_tag, t2t_out_cap]
            )

            # clear: attach inputs and outputs to each tab's clear button, if any
            if ram_btn_clear is not None:
                ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian])
            if t2t_btn_clear is not None:
                t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap])

        return demo

    # `launch(enable_queue=...)` is deprecated in Gradio 3.x and removed in
    # Gradio 4 (where it raises TypeError); enable queuing via `Blocks.queue()`,
    # which returns the Blocks instance so `launch()` can be chained.
    build_gui().queue().launch()