"""Gradio demo for asking questions about an image with H2OVL-Mississippi models.

The user picks a model, uploads an image and asks a question; the selected
model answers and the model-side conversation history is kept per session
in a ``gr.State``.
"""

import os
import threading

import gradio as gr
import torch
from huggingface_hub import login  # noqa: F401  (imported but not called in this file)
from transformers import AutoModel, AutoTokenizer

# Process-wide caches of loaded models / tokenizers, keyed by display name.
model_cache = {}
tokenizer_cache = {}
# Guards both caches so a model is loaded at most once and lookups never
# observe a half-populated cache entry.
model_lock = threading.Lock()

# Optional Hugging Face access token, read from the `hf_token` env variable.
hf_token = os.environ.get('hf_token', None)

# Display name (shown in the dropdown) -> Hugging Face Hub repository id.
model_paths = {
    "H2OVL-Mississippi-2B": "h2oai/h2ovl-mississippi-2b",
    "H2OVL-Mississippi-0.8B": "h2oai/h2ovl-mississippi-800m",
    # Add more models as needed
}

# Suggested questions offered in the question dropdown.
example_prompts = [
    "Read the text and provide word by word ocr for the document. ",
    "Read the text on the image",
    "Extract the text from the image.",
    "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}",
    "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount",
]


def load_model_and_set_image_function(model_name):
    """Ensure the model selected in the dropdown is loaded and cached.

    The return value is written to the session's ``model_state``, which the
    chat handlers use to look the model up in the cache.

    Args:
        model_name: A key of ``model_paths`` (the dropdown value), or None.

    Returns:
        ``model_name`` once its model and tokenizer are cached, or None when
        the selection is empty or unknown (the chat handlers then ask the
        user to select a model instead of failing with a KeyError).
    """
    model_path = model_paths.get(model_name)
    if model_path is None:
        print(f"Unknown or empty model selection: {model_name!r}")
        return None

    with model_lock:
        if model_name in model_cache:
            print(f"Model {model_name} is already loaded. Retrieving from cache.")
        else:
            print(f"Loading model {model_name}...")
            # NOTE(review): newer transformers releases deprecate
            # `use_auth_token` in favour of `token`; switch once the
            # installed version allows it.
            model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                use_auth_token=hf_token,
                # device_map="auto"
            ).eval().cuda()
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
                use_fast=False,
                use_auth_token=hf_token,
            )
            model_cache[model_name] = model
            tokenizer_cache[model_name] = tokenizer
            print(f"Model {model_name} loaded successfully.")
    return model_name


def _get_cached_model(model_name):
    """Return ``(model, tokenizer)`` for *model_name*, or None if not cached yet."""
    with model_lock:
        if model_name not in model_cache:
            return None
        return model_cache[model_name], tokenizer_cache[model_name]


def _is_invalid_message(message):
    """Return True for a missing/blank message or the literal word 'system'."""
    return not message or message.strip() == '' or message.lower() == 'system'


def _build_generation_config(temperature, top_p, max_new_tokens):
    """Translate the UI slider values into the model's ``generation_config``.

    Sampling is switched off (greedy decoding) when temperature is 0.
    """
    temperature = float(temperature)
    return dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=float(top_p),
    )


def _generate(model, tokenizer, model_name, image, message, tile_num,
              generation_config, history):
    """Run one question against the model; return ``(response, new_history)``.

    Models whose name contains "0.8b" are queried through ``model.ocr``;
    every other model goes through ``model.chat``. Using a single if/else
    guarantees a response is always produced (two independent ``if``s would
    leave it unbound for names matching neither pattern).
    """
    method = model.ocr if '0.8b' in model_name.lower() else model.chat
    return method(
        tokenizer,
        image,
        message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=history,
        return_history=True,
    )


def inference(image_input, user_message, temperature, top_p, max_new_tokens,
              tile_num, chatbot, state, model_name):
    """Answer *user_message* about *image_input* and append the turn to the chat.

    Args:
        image_input: Path of the uploaded image (``gr.Image(type="filepath")``)
            or None.
        user_message: The question text from the question dropdown.
        temperature, top_p, max_new_tokens, tile_num: Slider values.
        chatbot: Current chat as a list of ``(user, bot)`` pairs, or None.
        state: Model-side history returned by the previous turn, or None.
        model_name: Value of ``model_state`` (a cached model name) or None.

    Returns:
        ``(chatbot, state, "")``: the updated chat, the updated history, and
        an empty string that clears the question box. Validation problems
        are reported as ``("System", ...)`` chat entries.
    """
    # Normalize first: every early-exit path below appends to the chat.
    if chatbot is None:
        chatbot = []

    if model_name is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state, ""

    cached = _get_cached_model(model_name)
    if cached is None:
        chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
        return chatbot, state, ""
    model, tokenizer = cached

    if _is_invalid_message(user_message):
        chatbot.append(("System", "Please enter a valid message to continue the conversation."))
        return chatbot, state, ""

    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state, ""

    # Reserve the chat slot for this turn; the answer is filled in below.
    chatbot.append((user_message, None))

    generation_config = _build_generation_config(temperature, top_p, max_new_tokens)
    # `state` is None on the first turn and is passed through unchanged as
    # the history.
    response_text, state = _generate(
        model, tokenizer, model_name, image_input, user_message, tile_num,
        generation_config, state,
    )

    chatbot[-1] = (user_message, response_text)
    return chatbot, state, ""


def regenerate_response(chatbot, temperature, top_p, max_new_tokens, tile_num,
                        state, image_input, model_name):
    """Re-ask the last user question and replace the last answer in the chat.

    Args:
        chatbot: Current chat as a list of ``(user, bot)`` pairs, or None.
        temperature, top_p, max_new_tokens, tile_num: Slider values.
        state: Model-side history including the turn being regenerated.
        image_input: Path of the uploaded image or None.
        model_name: Value of ``model_state`` (a cached model name) or None.

    Returns:
        ``(chatbot, state)`` with the last answer replaced, or with a
        ``("System", ...)`` entry appended when regeneration is not possible.
    """
    # Normalize first: every early-exit path below appends to the chat.
    if chatbot is None:
        chatbot = []

    if model_name is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state

    cached = _get_cached_model(model_name)
    if cached is None:
        chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
        return chatbot, state
    model, tokenizer = cached

    if len(chatbot) == 0:
        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
        return chatbot, state

    last_user_message, _ = chatbot[-1]
    if _is_invalid_message(last_user_message):
        chatbot.append(("System", "Cannot regenerate response for an empty or invalid message."))
        return chatbot, state

    if image_input is None:
        chatbot.append(("System", "Please provide an image to regenerate the response."))
        return chatbot, state

    # Drop the most recent history entry (the turn being regenerated) so the
    # question is re-asked against the preceding history; an empty history
    # is passed as None, as on a first turn.
    history = (state[:-1] if state else None) or None

    generation_config = _build_generation_config(temperature, top_p, max_new_tokens)
    response_text, state = _generate(
        model, tokenizer, model_name, image_input, last_user_message, tile_num,
        generation_config, history,
    )

    chatbot[-1] = (last_user_message, response_text)
    return chatbot, state


def clear_all():
    """Reset the chat, the model history, the image and the question box."""
    return [], None, None, ""


def reset_chatbot_state():
    """Return an empty chat and an empty model history."""
    return [], None


def example_clicked(image_value, user_input_value):
    """Load an example's image and question, and reset the conversation."""
    chatbot_value, state_value = [], None
    return image_value, user_input_value, chatbot_value, state_value


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# **H2OVL-Mississippi**")
    state = gr.State()        # model-side conversation history of this session
    model_state = gr.State()  # name of the model loaded for this session

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(model_paths.keys()),
            label="Select Model",
            value="H2OVL-Mississippi-2B",
        )

    with gr.Row(equal_height=True):
        # First column with image input
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload an Image")

        # Second column with chatbot and user input
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            user_input = gr.Dropdown(
                label="What is your question",
                choices=example_prompts,
                value=None,
                allow_custom_value=True,
                interactive=True,
            )

    # When the model selection changes, load the new model and start a fresh
    # conversation.
    model_dropdown.change(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state],
    )
    model_dropdown.change(
        fn=reset_chatbot_state,
        inputs=None,
        outputs=[chatbot, state],
    )

    # Reset chatbot and state when image input changes
    image_input.change(
        fn=reset_chatbot_state,
        inputs=None,
        outputs=[chatbot, state],
    )

    # Load the default model when the app starts
    demo.load(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state],
    )

    with gr.Accordion('Parameters', open=False):
        with gr.Row():
            temperature_input = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.1, value=0.2,
                interactive=True, label="Temperature",
            )
            top_p_input = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.1, value=0.9,
                interactive=True, label="Top P",
            )
            max_new_tokens_input = gr.Slider(
                minimum=64, maximum=4096, step=64, value=1024,
                interactive=True, label="Max New Tokens (default: 1024)",
            )
            tile_num = gr.Slider(
                minimum=2, maximum=12, step=1, value=6,
                interactive=True, label="Tile Number (default: 6)",
            )

    with gr.Row():
        submit_button = gr.Button("Submit")
        regenerate_button = gr.Button("Regenerate")
        clear_button = gr.Button("Clear")

    # When the submit button is clicked, call the inference function
    submit_button.click(
        fn=inference,
        inputs=[
            image_input,
            user_input,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            chatbot,
            state,
            model_state,
        ],
        outputs=[chatbot, state, user_input],
    )

    # When the regenerate button is clicked, re-run the last inference
    regenerate_button.click(
        fn=regenerate_response,
        inputs=[
            chatbot,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            state,
            image_input,
            model_state,
        ],
        outputs=[chatbot, state],
    )

    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, state, image_input, user_input],
    )

    gr.Examples(
        examples=[
            ["assets/handwritten-note-example.jpg", "Read the text on the image"],
            ["assets/receipt.jpg", "Extract the text from the image."],
            ["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
            ["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
            ["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
        ],
        inputs=[image_input, user_input],
        outputs=[image_input, user_input, chatbot, state],
        fn=example_clicked,
        # Without caching, Gradio only calls `fn` on click when this is set;
        # it is needed for example_clicked to reset the chat and history.
        run_on_click=True,
        label="examples",
    )

demo.queue()
demo.launch(max_threads=10)