File size: 2,152 Bytes
d40f2f8
f2df246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e8f4c8
f2df246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c904757
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import subprocess

import torch
import gradio as gr
from clip_interrogator import Config, Interrogator


# Pre-processed clip_interrogator label caches for ViT-H-14 / laion2b_s32b_b79k
# (artists, flavors, mediums, movements, trendings), fetched at start-up.
# NOTE(review): presumably these let the interrogator skip rebuilding the
# caches itself; confirm Config's cache_path points at CACHE_DIR.
CACHE_URLS = [
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
]
CACHE_DIR = 'cache'

os.makedirs(CACHE_DIR, exist_ok=True)
for url in CACHE_URLS:
    # Skip files that are already present. Without this check every restart
    # with a persistent disk makes wget store another copy as "<name>.pkl.1",
    # "<name>.pkl.2", ... instead of reusing the existing file.
    if os.path.exists(os.path.join(CACHE_DIR, os.path.basename(url))):
        continue
    # Best-effort download: a non-zero wget exit status is deliberately not
    # raised (check=False), matching the original start-up behaviour.
    subprocess.run(['wget', url, '-P', CACHE_DIR], check=False)


# Build the single, module-wide interrogator used by inference().
# On a CUDA machine everything runs on the GPU; otherwise it runs on the CPU
# with BLIP offloading switched on.
# NOTE(review): the caches downloaded at start-up are for
# ViT-H-14/laion2b_s32b_b79k; confirm Config's default clip_model_name matches,
# otherwise those files are presumably ignored.
_use_cuda = torch.cuda.is_available()

config = Config()
config.device = 'cuda' if _use_cuda else 'cpu'
config.blip_offload = not _use_cuda
config.chunk_size = 2048
config.flavor_intermediate_count = 512
# A large beam count for BLIP captioning; presumably trades speed for quality.
config.blip_num_beams = 64

ci = Interrogator(config)

# @spaces.GPU  # left disabled; presumably for Hugging Face ZeroGPU (needs `import spaces`)
def inference(image, mode, best_max_flavors):
    """Return a text prompt describing *image*, produced by the CLIP Interrogator.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user pressed Submit without uploading an image.
        mode: ``'best'`` or ``'classic'``; any other value uses the fast path.
        best_max_flavors: maximum number of flavors, only used in ``'best'``
            mode. It is cast to ``int`` because the slider value may arrive
            as a float.

    Returns:
        The prompt string generated by the module-level interrogator ``ci``.

    Raises:
        gr.Error: if no image was supplied, so Gradio shows a readable message
            instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # The models are fed 3-channel input; uploads may be RGBA, greyscale, etc.
    image = image.convert('RGB')

    if mode == 'best':
        return ci.interrogate(image, max_flavors=int(best_max_flavors))
    if mode == 'classic':
        return ci.interrogate_classic(image)
    return ci.interrogate_fast(image)


# Single-column UI: image upload, mode and flavor controls side by side,
# a submit button, and a text box that receives the generated prompt.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(value="# CLIP Interrogator")
        input_image = gr.Image(elem_id="input-img", type='pil')
        with gr.Row():
            mode_input = gr.Radio(
                choices=['best', 'classic', 'fast'],
                value='best',
                label='Select mode',
            )
            # Only consulted by inference() when the mode is 'best'.
            flavor_input = gr.Slider(
                label='best mode max flavors',
                minimum=2,
                maximum=48,
                step=2,
                value=32,
            )
        submit_btn = gr.Button(value="Submit")
        output_text = gr.Textbox(label="Description Output")

    # The order of `inputs` must match inference(image, mode, best_max_flavors).
    # concurrency_limit caps how many of these jobs may run at the same time.
    submit_btn.click(
        inference,
        inputs=[input_image, mode_input, flavor_input],
        outputs=[output_text],
        concurrency_limit=10,
    )

# Enable the request queue, then start the web server. To listen on all
# network interfaces, pass server_name="0.0.0.0" to launch().
demo.queue()
demo.launch()