import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
from share_btn import community_icon_html, loading_icon_html, share_js
import tempfile
# Custom CSS passed to gr.Blocks below. It styles the "Share to community"
# button group (elem_ids "share-btn-container" and "share-btn"): a 13rem-wide,
# black, fully rounded container with margin-left: auto, and white bold
# IBM Plex Sans button text.
# NOTE(review): the `.wrap { display: none }` rule presumably hides Gradio's
# loading/progress wrappers inside the container — confirm for the Gradio
# version in use.
css = """
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
margin-top: 10px;
margin-left: auto;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
"""
# Download the FROMAGe checkpoint and its model-argument JSON from the
# Hugging Face Hub repo 'jykoh/fromage', then build the model from them.
# hf_hub_download returns the local path of each (cached) file. This runs at
# import time, so the app does not start serving until both downloads and the
# model construction have finished.
ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
# NOTE(review): './' is presumably the directory load_fromage reads auxiliary
# files from (e.g. retrieval embeddings) — confirm against fromage.models.
model = models.load_fromage('./', args_path, ckpt_path)
def upload_image(state, image_input):
    """Show an uploaded image in the chat and stash it for the next prompt.

    Args:
        state: The Gradio session state ``[conversation, chat_history,
            input_image]``.
        image_input: The object produced by ``gr.UploadButton``; only its
            ``.name`` attribute (a local file path) is used.

    Returns:
        A ``(new_state, conversation)`` pair. In ``new_state`` the third slot
        holds the uploaded image as a 224x224 RGB ``PIL.Image``.
        ``conversation`` is the chat list to render in the chatbot.
    """
    conversation = state[0]
    chat_history = state[1]
    # Display the uploaded file as a user turn with an empty bot reply.
    conversation += [(f"![](/file={image_input.name})", "")]
    # The context manager closes the underlying file handle once the pixels
    # have been copied out by convert(). Converting to RGB *before* resizing
    # matters: Pillow always uses nearest-neighbour when resizing mode '1'/'P'
    # images, which would degrade palette-based uploads.
    with Image.open(image_input.name) as raw_image:
        input_image = raw_image.convert('RGB').resize((224, 224))
    return [conversation, chat_history, input_image], conversation
def reset():
    """Clear the chat session.

    Returns:
        A ``(state, conversation)`` pair: a fresh session state
        ``[conversation, chat_history, input_image]`` with empty lists and no
        pending image, plus an empty conversation for the chatbot display.
    """
    empty_state = [list(), list(), None]
    return empty_state, list()
def save_image_to_local(image: "Image.Image"):
    """Save *image* as a uniquely named PNG in the current working directory.

    Args:
        image: The ``PIL.Image.Image`` to write.

    Returns:
        The saved file's path relative to the current working directory
        (a bare file name such as ``tmpab12cd34.png``).
    """
    # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
    # mkstemp creates the file atomically, so two calls can never pick the
    # same name. The private tempfile._get_candidate_names() used previously
    # only proposed a name, which could collide with an existing file.
    fd, path = tempfile.mkstemp(suffix='.png', dir=os.getcwd())
    try:
        with os.fdopen(fd, 'wb') as f:
            # The format must be explicit because we write to a file object.
            image.save(f, format='PNG')
    except BaseException:
        # Don't leave an empty/partial PNG behind if encoding fails.
        os.remove(path)
        raise
    return os.path.relpath(path)
def _render_response(model_outputs):
    """Render FROMAGe outputs as the text/HTML shown in the chatbot.

    Strings are appended verbatim. Each image, whether a single ``PIL.Image``
    or an element of a list output, is saved to a local PNG and embedded as an
    ``<img>`` tag served through Gradio's ``/file=`` route. Outputs of any
    other type are skipped.
    """
    parts = []
    for output in model_outputs:
        if isinstance(output, str):
            parts.append(output)
        elif isinstance(output, list):
            for image in output:
                filename = save_image_to_local(image)
                parts.append(f'<img src="/file={filename}">')
        elif isinstance(output, Image.Image):
            filename = save_image_to_local(output)
            parts.append(f'<img src="/file={filename}">')
    return ''.join(parts)


def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
    """Run one chat turn through the model and update the session state.

    Args:
        input_text: The user's message; it is wrapped as ``'Q: ...\\nA:'``.
        state: The session state ``[conversation, chat_history, input_image]``.
        ret_scale_factor: Passed through as ``ret_scale_factor`` to
            ``model.generate_for_images_and_texts``.
        max_nm_rets: Maximum number of images to return (cast to ``int``).
        num_words: Maximum number of words to generate (cast to ``int``).
        temperature: Sampling temperature; ``top_p`` is 1.0 when it is 0.0
            and 0.95 otherwise.

    Returns:
        A ``(new_state, conversation)`` pair. The pending input image is
        cleared in ``new_state`` because this turn consumed it.
    """
    input_prompt = 'Q: ' + input_text + '\nA:'
    conversation = state[0]
    chat_history = state[1]
    input_image = state[2]
    print('Generating for', chat_history, flush=True)

    # Copy instead of aliasing state[1]. Appending to the session's own list
    # would leave a dangling prompt in the state if generation raised, and
    # that prompt would be re-sent on the next turn.
    model_inputs = list(chat_history)
    # If an image was uploaded, feed it to the model before the prompt.
    if input_image is not None:
        model_inputs.append(input_image)
    model_inputs.append(input_prompt)

    top_p = 0.95 if temperature != 0.0 else 1.0

    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    # gr.Number may deliver floats (e.g. 32.0); these are counts, so pass ints.
    model_outputs = model.generate_for_images_and_texts(
        model_inputs,
        num_words=int(num_words),
        ret_scale_factor=ret_scale_factor,
        top_p=top_p,
        temperature=temperature,
        max_num_rets=int(max_nm_rets),
    )
    print('model_outputs', model_outputs, flush=True)

    response = _render_response(model_outputs)

    # TODO(jykoh): Persist image inputs.
    text_reply = ' '.join(s for s in model_outputs if isinstance(s, str))
    chat_history = model_inputs + [text_reply + '\n']
    conversation.append((input_text, response.replace('[RET]', '')))  # Remove [RET] from outputs.

    print('state', state, flush=True)
    print('updated state', [conversation, chat_history, None], flush=True)
    # Set input image to None: the uploaded image has been consumed.
    return [conversation, chat_history, None], conversation
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        '### Grounding Language Models to Images for Multimodal Generation'
    )
    gr.HTML("""
For faster inference without waiting in queue, you may duplicate the space and use your own GPU.
""")

    chatbot = gr.Chatbot(elem_id="chatbot")
    # Per-session state: [conversation, chat_history, input_image].
    gr_state = gr.State([[], [], None])

    # Hidden "Share to community" button group (styled by the `css` string).
    with gr.Group(elem_id="share-btn-container", visible=False):
        community_icon = gr.HTML(community_icon_html)
        loading_icon = gr.HTML(loading_icon_html)
        share_button = gr.Button("Share to community", elem_id="share-btn")

    with gr.Row():
        with gr.Column(scale=0.3, min_width=100):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
            # precision=0 makes Gradio return whole numbers: these are counts
            # (images to return / words to generate), not fractional values.
            max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=0, interactive=True, label="Max images to return")
            gr_max_len = gr.Number(value=32, precision=0, label="Max # of words returned", interactive=True)
            gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
        with gr.Column(scale=0.7, min_width=400):
            image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
            text_input = gr.Textbox(label="Chat Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
            with gr.Row():
                with gr.Column(scale=0.5):
                    submit_btn = gr.Button("Submit", interactive=True, variant="primary")
                with gr.Column(scale=0.5):
                    clear_btn = gr.Button("Clear History")

    # Enter and the Submit button trigger the same generation handler.
    generate_inputs = [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature]
    text_input.submit(generate_for_prompt, generate_inputs, [gr_state, chatbot])
    text_input.submit(lambda: "", None, text_input)  # Reset chatbox.
    submit_btn.click(generate_for_prompt, generate_inputs, [gr_state, chatbot])
    submit_btn.click(lambda: "", None, text_input)  # Reset chatbox.
    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])
    share_button.click(None, [], [], _js=share_js)

# One worker processes requests (the single model instance is shared), with
# at most 16 requests waiting in the queue.
demo.queue(concurrency_count=1, api_open=False, max_size=16)
demo.launch(debug=True, server_name="0.0.0.0")