#!/usr/bin/env python3
import gradio as gr
from PIL import Image
from clip_interrogator import Config, Interrogator

# Only the ViT-L model is loaded below, so the dropdown offers a single choice.
MODELS = ['ViT-L (best for Stable Diffusion 1.*)']

# load BLIP and ViT-L https://huggingface.co/openai/clip-vit-large-patch14
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
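# Note: constructing the Interrogator downloads the BLIP caption model and the
# CLIP weights on first run, so the first startup can take a while.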


def image_analysis(image, clip_model_name=MODELS[0]):
    # clip_model_name is accepted so the UI can pass the dropdown value,
    # but only the ViT-L model loaded above is actually used.
    image = image.convert('RGB')
    image_features = ci.image_to_features(image)

    # Rank the top 5 entries from each of the interrogator's label tables.
    top_mediums = ci.mediums.rank(image_features, 5)
    top_artists = ci.artists.rank(image_features, 5)
    top_movements = ci.movements.rank(image_features, 5)
    top_trendings = ci.trendings.rank(image_features, 5)
    top_flavors = ci.flavors.rank(image_features, 5)

    # Pair each label with its similarity score so gr.Label can display it.
    medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
    artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
    movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
    trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
    flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}

    return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
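
# Example usage outside the UI (a sketch; assumes example01.jpg sits next to
# this script, as the gr.Examples entries below already do):
#   ranks = image_analysis(Image.open('example01.jpg'))
#   medium_ranks = ranks[0]  # e.g. {'oil painting': 0.21, ...} (illustrative values)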


def image_to_prompt(image, mode):
    image = image.convert('RGB')
    if mode == 'best':
        prompt = ci.interrogate(image)
    elif mode == 'classic':
        prompt = ci.interrogate_classic(image)
    elif mode == 'fast':
        prompt = ci.interrogate_fast(image)
    elif mode == 'negative':
        prompt = ci.interrogate_negative(image)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return prompt
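
# Example usage (a sketch): 'fast' is the quickest mode and is handy for smoke
# tests, while 'best' produces a higher-quality prompt but takes longer.
#   print(image_to_prompt(Image.open('example01.jpg'), 'fast'))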
TITLE = """
CLIP Interrogator
Want to figure out what a good prompt might be to create new images like an existing one?
The CLIP Interrogator is here to get you answers!
You can skip the queue by duplicating this Space and upgrading to a GPU in the Settings.
"""
ARTICLE = """
"""
CSS = """
#col-container {margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
#share-btn-container {
  display: flex;
  padding-left: 0.5rem !important;
  padding-right: 0.5rem !important;
  background-color: #000000;
  justify-content: center;
  align-items: center;
  border-radius: 9999px !important;
  width: 13rem;
}
#share-btn {
  all: initial;
  color: #ffffff;
  font-weight: 600;
  cursor: pointer;
  font-family: 'IBM Plex Sans', sans-serif;
  margin-left: 0.5rem !important;
  padding-top: 0.25rem !important;
  padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
"""


def analyze_tab():
    with gr.Column():
        with gr.Row():
            image = gr.Image(type='pil', label="Image")
            model = gr.Dropdown(MODELS, value=MODELS[0], label='CLIP Model')
        with gr.Row():
            medium = gr.Label(label="Medium", num_top_classes=5)
            artist = gr.Label(label="Artist", num_top_classes=5)
            movement = gr.Label(label="Movement", num_top_classes=5)
            trending = gr.Label(label="Trending", num_top_classes=5)
            flavor = gr.Label(label="Flavor", num_top_classes=5)
        button = gr.Button("Analyze")
    button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])

    examples = [['example01.jpg', MODELS[0]], ['example02.jpg', MODELS[0]]]
    ex = gr.Examples(
        examples=examples,
        fn=image_analysis,
        inputs=[image, model],
        outputs=[medium, artist, movement, trending, flavor],
        cache_examples=True,
        run_on_click=True,
    )
    # Blank the auto-generated header row of the examples table.
    ex.dataset.headers = [""]


with gr.Blocks(css=CSS) as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(TITLE)

        with gr.Tab("Prompt"):
            with gr.Row():
                input_image = gr.Image(type='pil', elem_id="input-img")
                with gr.Column():
                    input_mode = gr.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Mode')
                    submit_btn = gr.Button("Submit")
            output_text = gr.Textbox(label="Output", elem_id="output-txt")

            examples = [['example01.jpg', 'best'], ['example02.jpg', 'best']]
            ex = gr.Examples(
                examples=examples,
                fn=image_to_prompt,
                inputs=[input_image, input_mode],
                outputs=[output_text],
                cache_examples=True,
                run_on_click=True,
            )
            ex.dataset.headers = [""]

        with gr.Tab("Analyze"):
            analyze_tab()

        gr.HTML(ARTICLE)

    submit_btn.click(
        fn=image_to_prompt,
        inputs=[input_image, input_mode],
        outputs=[output_text],
    )
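
# queue() makes requests run one at a time by default so the single loaded
# model is not hit concurrently; max_size caps how many requests can wait.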
block.queue(max_size=64).launch(show_api=False)