import os

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER while it is being imported,
# so the flag must be in the environment before that import (or any import
# that may pull it in transitively) runs. Setting it afterwards is a no-op.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

import tempfile

from share_btn import community_icon_html, loading_icon_html, share_js, save_js
import huggingface_hub
import gradio as gr
from fromage import utils
from fromage import models
import matplotlib.pyplot as plt
from PIL import Image
import torch
import numpy as np

# Stylesheet handed to gr.Blocks(css=...). It targets the element ids used by
# the demo ("chatbot", "save-btn", "share-btn-2") and the share-button
# container. The string is kept verbatim so the served CSS does not change.
css = """ #share-btn-container { display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; margin-top: 3px; margin-left: auto; flex: unset; } #share-btn { all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; } #share-btn * { all: unset; } #share-btn-container div:nth-child(-n+2){ width: auto !important; min-height: 0px !important; } #share-btn-container .wrap { display: none !important; } #chatbot { min-height: 300px; } #save-btn { background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); } #save-btn:hover { background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); } #share-btn-2 { background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); } #share-btn-2:hover { background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); } .message .user { } .message .bot { } """

# Example conversations; currently empty and only referenced by the
# commented-out gallery further down the file.
examples = []

# Download model from HF Hub.
ckpt_path = huggingface_hub.hf_hub_download(
    repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
args_path = huggingface_hub.hf_hub_download(
    repo_id='jykoh/fromage', filename='model_args.json')
model = models.load_fromage('./', args_path, ckpt_path)


def upload_image(state, image_input):
    """Register an uploaded image as the pending input for the next prompt.

    Args:
        state: Session state ``[conversation, chat_history, input_image]``.
        image_input: Uploaded file object; only its ``.name`` path is used.

    Returns:
        A pair ``(new_state, conversation)`` where the conversation has a
        Markdown image entry appended and ``new_state[2]`` is the resized image.
    """
    conversation = state[0]
    chat_history = state[1]
    # Decode inside a context manager so the file handle is released before
    # the same path is overwritten below.
    with Image.open(image_input.name) as uploaded:
        input_image = uploaded.resize((224, 224)).convert('RGB')
    input_image.save(image_input.name)  # Overwrite with smaller image.
    conversation += [(f"![](/file={image_input.name})", "")]
    return [conversation, chat_history, input_image], conversation


def reset():
    """Return a fresh, empty session state and an empty conversation."""
    return [[], [], None], []


def reset_last(state):
    """Undo the most recent turn.

    Drops the last conversation entry and the last two chat-history items
    (the prompt and its answer). The pending input image is left untouched.

    NOTE(review): an image upload adds a conversation entry but no
    chat-history items, so undoing right after an upload also removes the
    previous prompt/answer pair from the history -- confirm this is intended.
    """
    conversation = state[0][:-1]
    chat_history = state[1][:-2]
    input_image = state[2]
    return [conversation, chat_history, input_image], conversation


def save_image_to_local(image: Image.Image):
    """Save *image* as a uniquely named PNG in the working directory.

    Returns:
        The relative filename the image was written to.
    """
    # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
    # uuid4 replaces the private tempfile._get_candidate_names() helper.
    import uuid

    filename = uuid.uuid4().hex + '.png'
    image.save(filename)
    return filename


def _image_tag(filename):
    """Return the HTML snippet that embeds a saved image in the chatbot."""
    # The "/file=" prefix matches the one used for uploaded images.
    return f'<img src="/file={filename}" style="display: inline-block;">'


def _render_outputs(model_outputs):
    """Convert raw model outputs into chatbot markup.

    Args:
        model_outputs: Iterable of strings, PIL images, or lists of PIL images.

    Returns:
        ``(response, text_outputs)``: the markup for the chatbot and the list
        of plain-text outputs in their original order.
    """
    parts = []
    text_outputs = []
    for output in model_outputs:
        # isinstance (rather than an exact type comparison) also accepts PIL
        # subclasses such as images opened from a file.
        if isinstance(output, str):
            text_outputs.append(output)
            parts.append(output)
        elif isinstance(output, list):
            parts.append('<br/>')  # Add line break between images.
            parts.extend(
                _image_tag(save_image_to_local(image)) for image in output)
            parts.append('<br/>')
        elif isinstance(output, Image.Image):
            parts.append('<br/>')
            parts.append(_image_tag(save_image_to_local(output)))
            parts.append('<br/>')
    return ''.join(parts), text_outputs


def generate_for_prompt(input_text, state, ret_scale_factor, max_num_rets, num_words, temperature):
    """Run the model on the user's message and update the session.

    Args:
        input_text: The message typed by the user.
        state: Session state ``[conversation, chat_history, input_image]``.
        ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
        max_num_rets: Forwarded to the model as ``max_num_rets``.
        num_words: Maximum number of words to generate (clamped to >= 1).
        temperature: Sampling temperature; 0 disables nucleus sampling.

    Returns:
        ``(new_state, conversation, share_group_update, save_group_update)``,
        matching the four output components wired to this handler.
    """
    # Ignore empty inputs. Four values are returned here as well, because
    # the handler is registered with four output components.
    if not input_text:
        return state, state[0], gr.update(visible=True), gr.update(visible=True)

    input_prompt = 'Q: ' + input_text + '\nA:'
    conversation = state[0]
    chat_history = state[1]
    input_image = state[2]
    print('Generating for', chat_history, flush=True)

    # Work on a copy so the session's history is not mutated if generation
    # fails part-way. If an image was uploaded, it precedes the prompt.
    model_inputs = list(chat_history)
    if input_image is not None:
        model_inputs.append(input_image)
    model_inputs.append(input_prompt)

    top_p = 0.95 if temperature != 0.0 else 1.0

    print('Running model.generate_for_images_and_texts with',
          model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(
        model_inputs,
        num_words=max(num_words, 1),
        ret_scale_factor=ret_scale_factor,
        top_p=top_p,
        temperature=temperature,
        max_num_rets=max_num_rets)
    print('model_outputs', model_outputs, flush=True)

    response, text_outputs = _render_outputs(model_outputs)

    # TODO(jykoh): Persist image inputs.
    chat_history = model_inputs + [' '.join(text_outputs) + '\n']
    # Remove [RET] from outputs.
    conversation.append((input_text, response.replace('[RET]', '')))

    # Set input image to None.
    print('state', state, flush=True)
    print('updated state', [conversation, chat_history, None], flush=True)
    return [conversation, chat_history, None], conversation, gr.update(visible=True), gr.update(visible=True)


# NOTE(review): the layout nesting below was reconstructed from a
# whitespace-collapsed source; the widget order is preserved, but confirm the
# exact Row/Column nesting against the deployed app.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        '### Grounding Language Models to Images for Multimodal Generation'
    )

    gr.HTML(""" For faster inference without waiting in queue, you may duplicate the space and use your own GPU. Duplicate Space """)

    # Per-session state: [conversation, chat_history, input_image].
    gr_state = gr.State([[], [], None])

    with gr.Row():
        with gr.Column(scale=0.7, min_width=500):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chatbot", label="FROMAGe Chatbot")
            with gr.Row():
                image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])

            text_input = gr.Textbox(label="Message", placeholder="Type a message")

            with gr.Column():
                submit_btn = gr.Button(
                    "Submit", interactive=True, variant="primary")
                clear_last_btn = gr.Button("Undo")
                clear_btn = gr.Button("Reset All")
                with gr.Row(visible=False) as save_group:
                    save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn")

                with gr.Row(visible=False) as share_group:
                    share_button = gr.Button("🤗 Share to Community", elem_id="share-btn-2")

        with gr.Column(scale=0.3, min_width=200):
            ret_scale_factor = gr.Slider(
                minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
                label="Frequency multiplier for returning images (higher means more frequent)")
            max_ret_images = gr.Number(
                minimum=0, maximum=3, value=1, precision=1, interactive=True,
                label="Max images to return")
            gr_max_len = gr.Slider(
                minimum=1, maximum=64, value=32, step=1, interactive=True,
                label="Max # of words returned")
            gr_temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, interactive=True,
                label="Temperature (0 for deterministic, higher for more randomness)")

    # gallery = gr.Gallery(
    #     value=examples, label="Example Conversations", show_label=True, elem_id="gallery",
    # ).style(grid=[2], height="auto")

    # Pressing Enter in the textbox and clicking "Submit" do the same thing,
    # so both share one input/output wiring.
    generate_inputs = [text_input, gr_state, ret_scale_factor,
                       max_ret_images, gr_max_len, gr_temperature]
    generate_outputs = [gr_state, chatbot, share_group, save_group]

    text_input.submit(generate_for_prompt, generate_inputs, generate_outputs)
    text_input.submit(lambda: "", None, text_input)  # Reset chatbox.
    submit_btn.click(generate_for_prompt, generate_inputs, generate_outputs)
    submit_btn.click(lambda: "", None, text_input)  # Reset chatbox.

    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])
    share_button.click(None, [], [], _js=share_js)
    save_button.click(None, [], [], _js=save_js)


demo.queue(concurrency_count=1, api_open=False, max_size=16)
demo.launch(debug=True, server_name="0.0.0.0")