# aria/app.py
import requests
import torch
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
# Load model and processor
model_id_or_path = "rhymes-ai/Aria"
model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)
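
# A hedged alternative (an assumption, not part of the original script): on GPUs
# without bfloat16 support, the same checkpoint can be loaded in float16 instead.
# model = AutoModelForCausalLM.from_pretrained(
#     model_id_or_path, device_map="auto",
#     torch_dtype=torch.float16, trust_remote_code=True,
# )
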
# Function to process the input and generate text
def generate_response(image):
    # Gradio's filepath mode passes a local path; direct calls may pass a URL
    # or a PIL image instead.
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            image = Image.open(requests.get(image, stream=True).raw)
        else:
            image = Image.open(image)
    image = image.convert("RGB")  # ensure a 3-channel image
    # Prepare the chat-format prompt: one image slot followed by the question.
    messages = [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Match the image tensor dtype to the model and move all inputs to its device.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Generate a response; stop_strings requires passing the tokenizer so that
    # generate() can detect the end-of-turn marker.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=500,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=True,
            temperature=0.9,
        )
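    # Note: do_sample=True with temperature=0.9 gives varied answers across
    # runs; setting do_sample=False would make the description deterministic.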
    # Decode only the newly generated tokens, skipping the prompt.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    result = processor.decode(output_ids, skip_special_tokens=True)
    return result
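
# Optional smoke test for calling the function directly, bypassing the UI
# ("photo.jpg" is a placeholder path, not a file shipped with this Space):
# print(generate_response("photo.jpg"))
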
# Gradio interface (gr.inputs.Image was removed in Gradio 4; gr.Image replaces it)
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="Image-to-Text Model",
    description="Upload an image, and the model will describe it.",
)
# Launch the app
iface.launch()