import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces  # NOTE(review): imported but not referenced below (no @spaces.GPU); confirm intent.

# Model name and loading arguments.
repo_name = "cyan2k/molmo-7B-D-bnb-4bit"
arguments = {"device_map": "auto", "torch_dtype": "auto", "trust_remote_code": True}

# Load the processor and model once at import time; every request reuses them.
processor = AutoProcessor.from_pretrained(repo_name, **arguments)
model = AutoModelForCausalLM.from_pretrained(repo_name, **arguments)

# Predefined prompts offered in the dropdown.
prompts = [
    "Describe this image in detail",
    "What objects can you see in this image?",
    "What's the main subject of this image?",
    "Describe the colors in this image",
    "What emotions does this image evoke?",
]


def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
    """Run one generation for an image/prompt pair and return the reply text.

    Args:
        image: The image as a NumPy array (what ``gr.Image(type="numpy")`` delivers).
        text: The user's prompt.
        max_new_tokens: Upper bound on generated tokens (slider value; cast to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded newly generated text, with special tokens removed.
    """
    inputs = processor.process(images=[Image.fromarray(image)], text=text)
    # Move inputs to the model's device and make a batch of size 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    generation_config = GenerationConfig(
        max_new_tokens=int(max_new_tokens),
        # Without do_sample=True, decoding is greedy and transformers ignores
        # temperature/top_p, i.e. the "Advanced options" sliders would be dead.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        stop_strings="<|endoftext|>",
    )
    output = model.generate_from_batch(
        inputs,
        generation_config,
        tokenizer=processor.tokenizer,
    )

    # The output echoes the prompt; keep only the newly generated tokens.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)


def chatbot(image, text, history, max_new_tokens, temperature, top_p):
    """Answer ``text`` about ``image`` and return the updated chat history.

    ``history`` is a list of (user_message, bot_message) pairs. It is not
    modified in place; a new list with one appended turn is returned.
    """
    history = history or []
    if image is None:
        # The hint is the bot's reply, not something the user said.
        return history + [(text, "Please upload an image first.")]
    response = process_image_and_text(image, text, max_new_tokens, temperature, top_p)
    return history + [(text, response)]


def _respond(image, text, history, max_new_tokens, temperature, top_p):
    """Event handler: run ``chatbot`` and send the new history to both the
    chat display and the session state, so the two never drift apart."""
    updated = chatbot(image, text, history, max_new_tokens, temperature, top_p)
    return updated, updated


def update_textbox(prompt):
    """Copy the selected premade prompt into the question textbox."""
    return gr.update(value=prompt)


def update_raw_output(history):
    """Return the bot reply of the most recent turn, or "" when there is none."""
    if history:
        return history[-1][1]
    return ""


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")

    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()

    with gr.Row():
        text_input = gr.Textbox(placeholder="Ask a question about the image...")
        prompt_dropdown = gr.Dropdown(
            choices=[""] + prompts, label="Select a premade prompt", value=""
        )

    submit_button = gr.Button("Submit")
    clear_button = gr.ClearButton([text_input, chatbot_output])

    with gr.Accordion("Advanced options", open=False):
        max_new_tokens = gr.Slider(
            minimum=1, maximum=500, value=200, step=1, label="Max new tokens"
        )
        temperature = gr.Slider(
            minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
        )

    # Per-session conversation history: list of (user, bot) pairs.
    state = gr.State([])

    # Add copy button for raw output
    with gr.Row():
        raw_output = gr.Textbox(label="Raw Output", interactive=False)
        copy_button = gr.Button("Copy Raw Output")

    # Clearing must also reset the stored history; otherwise the next
    # submission brings the whole old conversation back.
    clear_button.add([state, raw_output])

    # Both the button and pressing Enter in the textbox run the same pipeline.
    generation_inputs = [image_input, text_input, state, max_new_tokens, temperature, top_p]
    for trigger in (submit_button.click, text_input.submit):
        trigger(
            _respond,
            inputs=generation_inputs,
            outputs=[chatbot_output, state],
        ).then(
            update_raw_output,
            inputs=[state],
            outputs=[raw_output],
        )

    prompt_dropdown.change(update_textbox, inputs=[prompt_dropdown], outputs=[text_input])
    # NOTE(review): this only copies the text into a hidden Textbox; it does not
    # place anything on the user's clipboard.
    copy_button.click(lambda x: gr.update(value=x), inputs=[raw_output], outputs=[gr.Textbox(visible=False)])

demo.launch()