import os
import logging

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)

hf_token = os.environ.get('hf_token', None)

# Define the models and their paths
model_paths = {
    "H2OVL-Mississippi-2B": "h2oai/h2ovl-mississippi-2b",
    "H2OVL-Mississippi-0.8B": "h2oai/h2ovl-mississippi-800m",
    # Add more models as needed
}


def load_model_and_tokenizer(model_name):
    # Get the model path from the model_paths dictionary
    model_path = model_paths[model_name]

    # Load the model in bfloat16 on the GPU
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        use_auth_token=hf_token
    ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=False,
        use_auth_token=hf_token
    )
    return model, tokenizer


def inference(image_input,
              user_message,
              temperature,
              top_p,
              max_new_tokens,
              tile_num,
              chatbot,
              state,
              model_state,
              tokenizer_state):

    # A model must be selected before the conversation can start
    if model_state is None or tokenizer_state is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state, ""

    # Check for an empty or invalid user message
    if not user_message or user_message.strip() == '' or user_message.lower() == 'system':
        chatbot.append(("System", "Please enter a valid message to continue the conversation."))
        return chatbot, state, ""

    model = model_state
    tokenizer = tokenizer_state

    if chatbot is None:
        chatbot = []

    # An image is required to start the conversation
    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state, ""

    # No need to initialize the history (state): model.chat handles None as an empty history

    # Append the user message to the chatbot; the response is filled in below
    chatbot.append((user_message, None))

    # Set the generation config
    do_sample = (float(temperature) != 0.0)

    generation_config = dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
    )

    # Call model.chat with the conversation history
    response_text, new_state = model.chat(
        tokenizer,
        image_input,
        user_message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=state,
        return_history=True
    )

    # Update the state with the new history
    state = new_state

    # Update the chatbot with the model's response
    chatbot[-1] = (user_message, response_text)

    return chatbot, state, ""


def regenerate_response(chatbot,
                        temperature,
                        top_p,
                        max_new_tokens,
                        tile_num,
                        state,
                        image_input,
                        model_state,
                        tokenizer_state):

    # A model must be selected before anything can be regenerated
    if model_state is None or tokenizer_state is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state

    model = model_state
    tokenizer = tokenizer_state

    # There must be a previous exchange to regenerate
    if chatbot is None or len(chatbot) == 0:
        chatbot = []
        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
        return chatbot, state

    # Get the last user message
    last_user_message, _ = chatbot[-1]

    # Check for an empty or invalid last user message
    if not last_user_message or last_user_message.strip() == '' or last_user_message.lower() == 'system':
        chatbot.append(("System", "Cannot regenerate response for an empty or invalid message."))
        return chatbot, state

    # Remove the last assistant response from the history so it can be re-answered
    if state:
        state = state[:-1]
    if not state:
        state = None

    # Set the generation config
    do_sample = (float(temperature) != 0.0)

    generation_config = dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
    )

    # Regenerate the response
    response_text, new_state = model.chat(
        tokenizer,
        image_input,
        last_user_message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=state,  # Excludes the last assistant response
        return_history=True
    )

    # Update the state with the new history
    state = new_state

    # Replace the last exchange with the regenerated response instead of appending a duplicate
    chatbot[-1] = (last_user_message, response_text)

    return chatbot, state


def clear_all():
    # Clear the chatbot, history state, image input, and text box
    return [], None, None, ""


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# **H2OVL-Mississippi**")

    state = gr.State()
    model_state = gr.State()
    tokenizer_state = gr.State()

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(model_paths.keys()),
            label="Select Model",
            value="H2OVL-Mississippi-2B"
        )

    # When the model selection changes, load the new model
    model_dropdown.change(
        fn=load_model_and_tokenizer,
        inputs=[model_dropdown],
        outputs=[model_state, tokenizer_state]
    )

    with gr.Row(equal_height=True):
        # First column with the image input
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload an Image")

        # Second column with the chatbot and user input
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            user_input = gr.Textbox(label="What is your question", placeholder="Type your message here")

            with gr.Accordion('Parameters', open=False):
                with gr.Row():
                    temperature_input = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.2,
                        interactive=True,
                        label="Temperature")
                    top_p_input = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.9,
                        interactive=True,
                        label="Top P")
                    max_new_tokens_input = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        step=64,
                        value=1024,
                        interactive=True,
                        label="Max New Tokens (default: 1024)")
                    tile_num = gr.Slider(
                        minimum=2,
                        maximum=12,
                        step=1,
                        value=6,
                        interactive=True,
                        label="Tile Number (default: 6)"
                    )

            with gr.Row():
                submit_button = gr.Button("Submit")
                regenerate_button = gr.Button("Regenerate")
                clear_button = gr.Button("Clear")

    # When the submit button is clicked, call the inference function
    submit_button.click(
        fn=inference,
        inputs=[
            image_input,
            user_input,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            chatbot,
            state,
            model_state,
            tokenizer_state
        ],
        outputs=[chatbot, state, user_input]
    )

    # When the regenerate button is clicked, re-run the last inference
    regenerate_button.click(
        fn=regenerate_response,
        inputs=[
            chatbot,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            state,
            image_input,
            model_state,
            tokenizer_state,
        ],
        outputs=[chatbot, state]
    )

    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, state, image_input, user_input]
    )

    gr.Examples(
        examples=[
            ["assets/driver_license.png",
             "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
            ["assets/receipt.jpg", "Read the text on the image"],
            ["assets/invoice.png",
             "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
            ["assets/CBA-1H23-Results-Presentation_wheel.png",
             "What is the efficiency of H2O.AI in document processing?"],
        ],
        inputs=[image_input, user_input],
        label="examples",
    )

demo.launch(share=True)
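
# ---------------------------------------------------------------------------
# A minimal, hedged usage sketch (left commented out so the app above runs
# unchanged): calling the model's remote-code `chat` API directly, outside
# Gradio, with the same arguments `inference` passes above. It assumes a CUDA
# GPU and, for gated checkpoints, the `hf_token` environment variable; the
# image path reuses one of the bundled examples, and max_new_tokens=512 is an
# arbitrary illustrative choice.
#
#   model, tokenizer = load_model_and_tokenizer("H2OVL-Mississippi-2B")
#   generation_config = dict(num_beams=1, max_new_tokens=512, do_sample=False)
#   answer, history = model.chat(
#       tokenizer,
#       "assets/receipt.jpg",
#       "Read the text on the image",
#       max_tiles=6,
#       generation_config=generation_config,
#       history=None,
#       return_history=True,
#   )
#   print(answer)
# ---------------------------------------------------------------------------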