"""Gradio demo app for chatting with H2OVL-Mississippi vision-language models.

The user picks a model, uploads an image and asks questions about it. Models
are loaded onto the GPU on first selection and kept in module-level caches.
"""

import os
import threading

import gradio as gr
import torch
from huggingface_hub import login  # noqa: F401 - currently unused; the token is passed explicitly below
from transformers import AutoModel, AutoTokenizer

# Loaded models/tokenizers keyed by display name. Both dicts are only touched
# while holding ``model_lock`` so a model is never loaded twice concurrently.
model_cache = {}
tokenizer_cache = {}
model_lock = threading.Lock()

# Optional Hugging Face access token, read from the ``hf_token`` env var.
hf_token = os.environ.get('hf_token', None)

# Display name -> Hugging Face Hub repository id.
model_paths = {
    "H2OVL-Mississippi-2B": "h2oai/h2ovl-mississippi-2b",
    "H2OVL-Mississippi-0.8B": "h2oai/h2ovl-mississippi-800m",
    # Add more models as needed
}

# Suggested prompts offered in the question dropdown.
example_prompts = [
    "Read the text and provide word by word ocr for the document. ",
    "Extract the text from the image.",
    "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}",
    "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount",
]


def load_model_and_set_image_function(model_name):
    """Load ``model_name`` and its tokenizer into the cache if not already there.

    Args:
        model_name: A key of ``model_paths`` (the model dropdown value).

    Returns:
        ``model_name`` once the model is cached. Gradio stores this in
        ``model_state``. Returns ``None`` when the selection is not a known
        model.
    """
    model_path = model_paths.get(model_name)
    if model_path is None:
        # e.g. the dropdown was cleared: leave model_state empty so the chat
        # handlers ask the user to select a model instead of raising KeyError.
        print(f"Unknown model selection: {model_name!r}")
        return None

    # The lock is held for the whole load, so a concurrent request for the
    # same model waits for this load rather than starting a second one.
    with model_lock:
        if model_name in model_cache:
            print(f"Model {model_name} is already loaded. Retrieving from cache.")
            return model_name

        print(f"Loading model {model_name}...")
        model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            token=hf_token,
            # device_map="auto"
        ).eval().cuda()
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=False,
            token=hf_token,
        )
        model_cache[model_name] = model
        tokenizer_cache[model_name] = tokenizer
        print(f"Model {model_name} loaded successfully.")
    return model_name


def _get_loaded_model(model_name):
    """Return ``(model, tokenizer)`` for a cached model, or ``None`` if absent."""
    with model_lock:
        if model_name not in model_cache:
            return None
        return model_cache[model_name], tokenizer_cache[model_name]


def _is_invalid_message(message):
    """Return True for empty/whitespace-only messages or the word "system"."""
    return not message or not message.strip() or message.strip().lower() == "system"


def _build_generation_config(temperature, top_p, max_new_tokens):
    """Build the ``generation_config`` dict passed to ``model.chat``.

    A temperature of exactly 0 switches to greedy decoding (``do_sample=False``).
    """
    temperature = float(temperature)
    return dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=float(top_p),
    )


def inference(image_input, user_message, temperature, top_p, max_new_tokens,
              tile_num, chatbot, state, model_name):
    """Answer ``user_message`` about ``image_input`` and append it to the chat.

    Args:
        image_input: File path of the uploaded image (``gr.Image(type="filepath")``).
        user_message: Question text from the prompt dropdown.
        temperature, top_p, max_new_tokens: Generation settings from the sliders.
        tile_num: Forwarded to ``model.chat`` as ``max_tiles``.
        chatbot: List of ``(user, assistant)`` pairs shown in the UI, or None.
        state: History returned by the previous ``model.chat`` call, or None.
        model_name: Name stored in ``model_state`` by the model loader.

    Returns:
        ``(chatbot, state, "")``. The empty string clears the prompt box.
    """
    # Normalise first: every early-exit path below appends to ``chatbot``.
    if chatbot is None:
        chatbot = []

    if model_name is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state, ""

    loaded = _get_loaded_model(model_name)
    if loaded is None:
        chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
        return chatbot, state, ""
    model, tokenizer = loaded

    if _is_invalid_message(user_message):
        chatbot.append(("System", "Please enter a valid message to continue the conversation."))
        return chatbot, state, ""

    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state, ""

    # Placeholder entry, replaced with the answer once generation finishes.
    chatbot.append((user_message, None))

    response_text, new_state = model.chat(
        tokenizer,
        image_input,
        user_message,
        max_tiles=int(tile_num),
        generation_config=_build_generation_config(temperature, top_p, max_new_tokens),
        history=state,  # model.chat handles None as an empty history
        return_history=True,
    )

    chatbot[-1] = (user_message, response_text)
    return chatbot, new_state, ""


def regenerate_response(chatbot, temperature, top_p, max_new_tokens, tile_num,
                        state, image_input, model_name):
    """Re-run generation for the last user message in ``chatbot``.

    The most recent turn is dropped from ``state`` before asking again. The
    last chatbot entry is then replaced with the new answer.

    Returns:
        ``(chatbot, state)``.
    """
    # Normalise first: every early-exit path below appends to ``chatbot``.
    if chatbot is None:
        chatbot = []

    if model_name is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state

    loaded = _get_loaded_model(model_name)
    if loaded is None:
        chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
        return chatbot, state
    model, tokenizer = loaded

    if not chatbot:
        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
        return chatbot, state

    last_user_message, _ = chatbot[-1]
    if _is_invalid_message(last_user_message):
        chatbot.append(("System", "Cannot regenerate response for an empty or invalid message."))
        return chatbot, state

    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state

    # Drop the most recent turn from the history; an empty history becomes None.
    history = None
    if state:
        history = state[:-1] or None

    response_text, new_state = model.chat(
        tokenizer,
        image_input,
        last_user_message,
        max_tiles=int(tile_num),
        generation_config=_build_generation_config(temperature, top_p, max_new_tokens),
        history=history,
        return_history=True,
    )

    chatbot[-1] = (last_user_message, response_text)
    return chatbot, new_state


def clear_all():
    """Reset the chatbot, conversation state, image input and prompt box."""
    return [], None, None, ""


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# **H2OVL-Mississippi**")
    state = gr.State()        # model.chat history for the current conversation
    model_state = gr.State()  # name of the currently loaded model

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(model_paths.keys()),
            label="Select Model",
            value="H2OVL-Mississippi-2B",
        )

    with gr.Row(equal_height=True):
        # First column with image input
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload an Image")

        # Second column with chatbot and user input
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            user_input = gr.Dropdown(
                label="What is your question",
                choices=example_prompts,
                value=None,
                allow_custom_value=True,
                interactive=True,
            )

    def reset_chatbot_state():
        """Clear the chat transcript and conversation history."""
        return [], None

    # When the model selection changes, load the new model and reset the chat.
    model_dropdown.change(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state],
    )
    model_dropdown.change(
        fn=reset_chatbot_state,
        inputs=None,
        outputs=[chatbot, state],
    )

    # Reset chatbot and state when the image changes.
    image_input.change(
        fn=reset_chatbot_state,
        inputs=None,
        outputs=[chatbot, state],
    )

    # Load the default model when the app starts.
    demo.load(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state],
    )

    with gr.Accordion('Parameters', open=False):
        with gr.Row():
            temperature_input = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.1, value=0.2,
                interactive=True, label="Temperature",
            )
            top_p_input = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.1, value=0.9,
                interactive=True, label="Top P",
            )
            # transformers rejects max_new_tokens <= 0, so the slider starts at
            # one step (64) instead of 0.
            max_new_tokens_input = gr.Slider(
                minimum=64, maximum=4096, step=64, value=1024,
                interactive=True, label="Max New Tokens (default: 1024)",
            )
            tile_num = gr.Slider(
                minimum=2, maximum=12, step=1, value=6,
                interactive=True, label="Tile Number (default: 6)",
            )

    with gr.Row():
        submit_button = gr.Button("Submit")
        regenerate_button = gr.Button("Regenerate")
        clear_button = gr.Button("Clear")

    # When the submit button is clicked, call the inference function.
    submit_button.click(
        fn=inference,
        inputs=[
            image_input,
            user_input,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            chatbot,
            state,
            model_state,
        ],
        outputs=[chatbot, state, user_input],
    )

    # When the regenerate button is clicked, re-run the last inference.
    regenerate_button.click(
        fn=regenerate_response,
        inputs=[
            chatbot,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            state,
            image_input,
            model_state,
        ],
        outputs=[chatbot, state],
    )

    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, state, image_input, user_input],
    )

    def example_clicked(image_value, user_input_value):
        """Populate the inputs from an example and start a fresh conversation."""
        chatbot_value, state_value = [], None
        return image_value, user_input_value, chatbot_value, state_value

    gr.Examples(
        examples=[
            ["assets/handwritten-note-example.jpg", "Read the text on the image"],
            ["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
            ["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
            ["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
        ],
        inputs=[image_input, user_input],
        outputs=[image_input, user_input, chatbot, state],
        fn=example_clicked,
        label="examples",
    )

demo.queue()
demo.launch(max_threads=10)