# vlm-o/app.py
import gradio as gr
import torch
from PIL import Image
from processor import MultiModalProcessor
from inference import test_inference
from load_model import load_hf_model
# Load model and processor
MODEL_PATH = "merve/paligemma_vqav2" # or your local model path
TOKENIZER_PATH = "./tokenizer" # path to your local tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer = load_hf_model(MODEL_PATH, TOKENIZER_PATH, device)
model = model.eval()
num_image_tokens = model.config.vision_config.num_image_tokens
image_size = model.config.vision_config.image_size
max_length = 512
processor = MultiModalProcessor(tokenizer, num_image_tokens, image_size, max_length)
def generate_caption(image, prompt, max_tokens=300, temperature=0.8, top_p=0.9, do_sample=False):
    if image is None:
        return "Please upload an image first."

    # Save the input image temporarily
    temp_image_path = "temp_image.jpg"
    Image.fromarray(image).save(temp_image_path)

    # Use the existing test_inference function, capturing everything it prints
    result = []

    def capture_print(*args, **kwargs):
        # Mirror print()'s default sep/end so the captured text matches console output
        result.append(kwargs.get("sep", " ").join(str(a) for a in args) + kwargs.get("end", "\n"))

    import builtins
    original_print = builtins.print
    builtins.print = capture_print
    try:
        test_inference(
            model,
            processor,
            device,
            prompt,
            temp_image_path,
            max_tokens,
            temperature,
            top_p,
            do_sample,
        )
    finally:
        # Always restore the real print, even if inference raises
        builtins.print = original_print

    # Return the captured output
    return "".join(result)
# Define Gradio demo
with gr.Blocks(title="Image Captioning with PaliGemma", theme=gr.themes.Monochrome()) as demo:
    gr.Markdown(
        """
        # Image Captioning with PaliGemma
        This demo uses the PaliGemma model to generate captions for images.
        """
    )
    with gr.Tabs():
        with gr.TabItem("Generate Caption"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(type="numpy", label="Upload Image")
                    prompt_input = gr.Textbox(label="Prompt", placeholder="What is happening in the photo?")
                with gr.Column(scale=1):
                    with gr.Group():
                        max_tokens_input = gr.Slider(1, 500, value=300, step=1, label="Max Tokens")
                        temperature_input = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
                        top_p_input = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top P")
                        do_sample_input = gr.Checkbox(label="Do Sample")
                    generate_button = gr.Button("Generate Caption")
                    output = gr.Textbox(label="Generated Caption", lines=5)
        with gr.TabItem("About"):
            gr.Markdown(
                """
                ## How to use:
                1. Upload an image in the 'Generate Caption' tab.
                2. Enter a prompt to guide the caption generation.
                3. Adjust the generation parameters if desired.
                4. Click 'Generate Caption' to see the results.

                ## Model Details:
                - Model: PaliGemma
                - Type: Multimodal (Text + Image)
                - Task: Image Captioning
                """
            )
    generate_button.click(
        generate_caption,
        inputs=[image_input, prompt_input, max_tokens_input, temperature_input, top_p_input, do_sample_input],
        outputs=output
    )
# Launch the demo
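# Optional note (assumption, not part of the original app): on a shared or small GPU,
# enabling Gradio's request queue, e.g. demo.queue(max_size=8), before demo.launch()
# keeps concurrent generations from running at the same time.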
if __name__ == "__main__":
    demo.launch()