Shanshan Wang
set default models
da76dba
raw
history blame
9.9 kB
import functools
import logging
import os

import gradio as gr
import torch
import torchvision.transforms as T
from huggingface_hub import login
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor

logging.basicConfig(level=logging.INFO)
# Optional Hugging Face access token, read from the `hf_token` environment
# variable; None when the variable is not set.
hf_token = os.environ.get('hf_token')

# Dropdown label -> Hugging Face Hub repository id for each selectable model.
# Add more models as needed.
model_paths = {
    "H2OVL-Mississippi-2B": "h2oai/h2ovl-mississippi-2b",
    "H2OVL-Mississippi-0.8B": "h2oai/h2ovl-mississippi-800m",
}
@functools.lru_cache(maxsize=None)
def _load_model_cached(model_path):
    """Load the model and tokenizer for a Hub repo id, at most once per process.

    The cache is bounded in practice because callers only pass values taken
    from ``model_paths``. Reusing the loaded objects avoids re-reading the
    weights every time the dropdown changes, and avoids one GPU copy per
    browser session.
    """
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        # `token` replaces the deprecated `use_auth_token` argument.
        token=hf_token,
    ).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=False,
        token=hf_token,
    )
    return model, tokenizer


def load_model_and_set_image_function(model_name):
    """Return ``(model, tokenizer)`` for the model selected in the dropdown.

    Despite its name, this function only loads the model and tokenizer; no
    image-processing function is set. The name is kept for existing callers.

    Args:
        model_name: A key of ``model_paths``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in bfloat16, in eval
        mode, and on the CUDA device.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_paths``.
    """
    model_path = model_paths[model_name]
    return _load_model_cached(model_path)
def inference(image_input,
              user_message,
              temperature,
              top_p,
              max_new_tokens,
              tile_num,
              chatbot,
              state,
              model_state,
              tokenizer_state):
    """Run one chat turn of the loaded model on the uploaded image.

    Args:
        image_input: Image handed to ``model.chat`` (the UI's image component
            uses ``type="filepath"``), or ``None`` if nothing was uploaded.
        user_message: The user's question.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Upper bound on the number of generated tokens.
        tile_num: Forwarded to ``model.chat`` as ``max_tiles``.
        chatbot: Transcript as a list of ``(user, bot)`` pairs, or ``None``.
        state: History returned by the previous ``model.chat`` call, or
            ``None`` for a new conversation.
        model_state: The loaded model, or ``None`` if none is loaded yet.
        tokenizer_state: The tokenizer that matches ``model_state``.

    Returns:
        ``(chatbot, state, "")``: the updated transcript, the new history, and
        an empty string that clears the question textbox.
    """
    # Normalise first: every early-exit path below appends to the transcript.
    if chatbot is None:
        chatbot = []

    if model_state is None or tokenizer_state is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state, ""

    # Reject empty or whitespace-only input, and the literal word 'system'.
    if not user_message or not user_message.strip() or user_message.lower() == 'system':
        chatbot.append(("System", "Please enter a valid message to continue the conversation."))
        return chatbot, state, ""

    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state, ""

    # A temperature of exactly 0 means greedy decoding.
    do_sample = float(temperature) != 0.0
    generation_config = dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
    )

    # A `state` of None is passed through unchanged; model.chat treats it as
    # an empty history.
    response_text, new_state = model_state.chat(
        tokenizer_state,
        image_input,
        user_message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=state,
        return_history=True,
    )

    chatbot.append((user_message, response_text))
    return chatbot, new_state, ""
def regenerate_response(chatbot,
                        temperature,
                        top_p,
                        max_new_tokens,
                        tile_num,
                        state,
                        image_input,
                        model_state,
                        tokenizer_state):
    """Re-ask the last user question and replace its answer.

    The last ``(question, answer)`` turn is dropped from ``state``, the same
    question is sent to ``model.chat`` again, and the last transcript row is
    overwritten with the new answer. The transcript and the model history
    therefore stay one row per turn.

    Args:
        chatbot: Transcript as a list of ``(user, bot)`` pairs, or ``None``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Upper bound on the number of generated tokens.
        tile_num: Forwarded to ``model.chat`` as ``max_tiles``.
        state: History returned by the previous ``model.chat`` call, or ``None``.
        image_input: Image handed to ``model.chat``, or ``None``.
        model_state: The loaded model, or ``None`` if none is loaded yet.
        tokenizer_state: The tokenizer that matches ``model_state``.

    Returns:
        ``(chatbot, state)``: the updated transcript and history.
    """
    # Normalise first: every early-exit path below appends to the transcript.
    if chatbot is None:
        chatbot = []

    if model_state is None or tokenizer_state is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state

    if not chatbot:
        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
        return chatbot, state

    last_user_message, _ = chatbot[-1]
    # This also rejects rows produced by the "System" notices above.
    if not last_user_message or not last_user_message.strip() or last_user_message.lower() == 'system':
        chatbot.append(("System", "Cannot regenerate response for an empty or invalid message."))
        return chatbot, state

    if image_input is None:
        chatbot.append(("System", "Please provide an image to regenerate the response."))
        return chatbot, state

    # Drop the turn being regenerated. Slicing copies the list, so the
    # caller's `state` is not mutated. An empty or missing history becomes
    # None, so model.chat treats it as a fresh conversation.
    history = state[:-1] if state else None
    if not history:
        history = None

    do_sample = float(temperature) != 0.0
    generation_config = dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
    )

    response_text, new_state = model_state.chat(
        tokenizer_state,
        image_input,
        last_user_message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=history,
        return_history=True,
    )

    # Overwrite rather than append: the history already replaced this turn,
    # so appending would duplicate the question in the transcript.
    chatbot[-1] = (last_user_message, response_text)
    return chatbot, new_state
def clear_all():
    """Return reset values for ``(chatbot, state, image_input, user_input)``.

    The transcript becomes a new empty list, the chat history and the uploaded
    image become ``None``, and the question textbox is emptied.
    """
    fresh_transcript = []
    cleared_history = None
    cleared_image = None
    cleared_question = ""
    return fresh_transcript, cleared_history, cleared_image, cleared_question
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# **H2OVL-Mississippi**")

    # Per-session values: `state` carries the model's chat history between
    # turns; `model_state` / `tokenizer_state` hold the loaded model objects.
    state = gr.State()
    model_state = gr.State()
    tokenizer_state = gr.State()

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(model_paths),
            label="Select Model",
            value="H2OVL-Mississippi-2B",
        )

    # When the model selection changes, load the new model
    model_dropdown.change(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state, tokenizer_state],
    )

    with gr.Row(equal_height=True):
        # First column with image input
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload an Image")
        # Second column with chatbot and user input
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            user_input = gr.Textbox(label="What is your question", placeholder="Type your message here")

    with gr.Accordion('Parameters', open=False):
        with gr.Row():
            temperature_input = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.2,
                interactive=True,
                label="Temperature")
            top_p_input = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.9,
                interactive=True,
                label="Top P")
            # The minimum is one step rather than 0: generation with
            # max_new_tokens=0 is rejected.
            max_new_tokens_input = gr.Slider(
                minimum=64,
                maximum=4096,
                step=64,
                value=1024,
                interactive=True,
                label="Max New Tokens (default: 1024)")
            tile_num = gr.Slider(
                minimum=2,
                maximum=12,
                step=1,
                value=6,
                interactive=True,
                label="Tile Number (default: 6)"
            )

    with gr.Row():
        submit_button = gr.Button("Submit")
        regenerate_button = gr.Button("Regenerate")
        clear_button = gr.Button("Clear")

    # When the submit button is clicked, call the inference function
    submit_button.click(
        fn=inference,
        inputs=[
            image_input,
            user_input,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            chatbot,
            state,
            model_state,
            tokenizer_state,
        ],
        outputs=[chatbot, state, user_input],
    )

    # When the regenerate button is clicked, re-run the last inference
    regenerate_button.click(
        fn=regenerate_response,
        inputs=[
            chatbot,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            state,
            image_input,
            model_state,
            tokenizer_state,
        ],
        outputs=[chatbot, state],
    )

    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, state, image_input, user_input],
    )

    gr.Examples(
        examples=[
            ["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
            ["assets/receipt.jpg", "Read the text on the image"],
            ["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
            ["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
        ],
        inputs=[image_input, user_input],
        label="examples",
    )

    # `.change` does not fire for the dropdown's initial value. Without this,
    # model_state stays None until the user switches models, and every
    # message is answered with "Please select a model ...". Load the
    # default model when the page loads instead.
    demo.load(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state, tokenizer_state],
    )

demo.launch(share=True)