import gradio as gr
import torch
from PIL import Image

from processor import MultiModalProcessor
from inference import test_inference
from load_model import load_hf_model

# Load model and processor
MODEL_PATH = "merve/paligemma_vqav2"  # or your local model path
TOKENIZER_PATH = "./tokenizer"  # path to your local tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

model, tokenizer = load_hf_model(MODEL_PATH, TOKENIZER_PATH, device)
model = model.eval()

num_image_tokens = model.config.vision_config.num_image_tokens
image_size = model.config.vision_config.image_size
max_length = 512
processor = MultiModalProcessor(tokenizer, num_image_tokens, image_size, max_length)
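# For reference: for the 224-pixel PaliGemma variants these config values are
# typically num_image_tokens = 256 (a 16x16 grid of patches) and image_size = 224;
# other checkpoints may differ, which is why they are read from the config above.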

def generate_caption(image, prompt, max_tokens=300, temperature=0.8, top_p=0.9, do_sample=False):
    # Save the input image temporarily so test_inference can read it from disk
    temp_image_path = "temp_image.jpg"
    Image.fromarray(image).save(temp_image_path)

    # test_inference prints its output, so temporarily redirect print to capture it
    result = []

    def capture_print(*args, **kwargs):
        result.append(" ".join(str(arg) for arg in args))

    import builtins
    original_print = builtins.print
    builtins.print = capture_print
    try:
        test_inference(
            model,
            processor,
            device,
            prompt,
            temp_image_path,
            max_tokens,
            temperature,
            top_p,
            do_sample,
        )
    finally:
        # Restore the original print even if inference raises
        builtins.print = original_print

    # Return the captured output
    return "".join(result)

# Define the Gradio demo
with gr.Blocks(title="Image Captioning with PaliGemma", theme=gr.themes.Monochrome()) as demo:
    gr.Markdown(
        """
        # Image Captioning with PaliGemma
        This demo uses the PaliGemma model to generate captions for images.
        """
    )

    with gr.Tabs():
        with gr.TabItem("Generate Caption"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(type="numpy", label="Upload Image")
                    prompt_input = gr.Textbox(label="Prompt", placeholder="What is happening in the photo?")
                with gr.Column(scale=1):
                    with gr.Group():
                        max_tokens_input = gr.Slider(1, 500, value=300, step=1, label="Max Tokens")
                        temperature_input = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
                        top_p_input = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top P")
                        do_sample_input = gr.Checkbox(label="Do Sample")
                    generate_button = gr.Button("Generate Caption")
                    output = gr.Textbox(label="Generated Caption", lines=5)

        with gr.TabItem("About"):
            gr.Markdown(
                """
                ## How to use:
                1. Upload an image in the 'Generate Caption' tab.
                2. Enter a prompt to guide the caption generation.
                3. Adjust the generation parameters if desired.
                4. Click 'Generate Caption' to see the results.

                ## Model Details:
                - Model: PaliGemma
                - Type: Multimodal (Text + Image)
                - Task: Image Captioning
                """
            )

    generate_button.click(
        generate_caption,
        inputs=[image_input, prompt_input, max_tokens_input, temperature_input, top_p_input, do_sample_input],
        outputs=output,
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()
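    # Note: demo.launch() serves locally by default; a temporary public link
    # can be requested with demo.launch(share=True) if needed.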