import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from io import BytesIO
from PIL import Image
import spaces # Import spaces for ZeroGPU support

# Load the chat processor and tokenizer (the model itself is loaded inside the GPU function)
model_path = "deepseek-ai/deepseek-vl-1.3b-chat"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
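
# Note: DeepSeek-VL prompts mark where an image goes with the "<image_placeholder>" tag;
# the processor expands that tag into image tokens when preparing the model inputs below.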
# Define the function for image description with ZeroGPU support
@spaces.GPU # Ensures GPU allocation for this function
def describe_image(image, user_question="Describe this image in great detail."):
    try:
        # Convert the PIL image to an in-memory PNG so the byte stream can be reused
        image_byte_arr = BytesIO()
        image.save(image_byte_arr, format="PNG")
        image_byte_arr.seek(0)  # Move the pointer back to the start

        # Define the conversation, using the user's question
        conversation = [
            {
                "role": "User",
                "content": f"<image_placeholder>{user_question}",
                "images": [image_byte_arr]  # Pass the image byte stream
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        # Re-open the byte stream as a PIL image for the processor
        pil_images = [Image.open(BytesIO(image_byte_arr.read()))]
        image_byte_arr.seek(0)  # Reset the byte stream for potential reuse

        # Tokenize the conversation, preprocess the image, and move the batch to the GPU
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to('cuda')
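
        # prepare_inputs now bundles the tokenized prompt, the preprocessed image tensors,
        # and the attention/image masks as a single batch ready for the model.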

        # Load the model in bfloat16, move it to the GPU, and switch to eval mode
        vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        ).to(torch.bfloat16).cuda().eval()

        # Fuse the image features and the text tokens into one embedding sequence
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

        # Generate the model's response
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=False,
            use_cache=True
        )
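
        # Since generation starts from inputs_embeds (no input_ids were passed),
        # outputs should contain only the newly generated token ids.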

        # Decode the generated tokens into text
        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return answer

    except Exception as e:
        # Return the error message so it appears in the UI instead of crashing the app
        return f"Error: {str(e)}"
# Gradio interface
def gradio_app():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Image Description with DeepSeek VL 1.3b 🐬\n"
            "### Upload an image and ask a question about it."
        )

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an Image")
            question_input = gr.Textbox(
                label="Question (optional)",
                placeholder="Ask a question about the image (e.g., 'What is happening in this image?')",
                lines=2
            )

        output_text = gr.Textbox(label="Image Description", interactive=False)
        submit_btn = gr.Button("Generate Description")
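
        # Wire the button: send both the uploaded image and the question to describe_image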
        submit_btn.click(
            fn=describe_image,
            inputs=[image_input, question_input],
            outputs=output_text
        )

    demo.launch()
# Launch the Gradio app
gradio_app()