import streamlit as st
import torch
from PIL import Image
import gc
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from torch.utils.data import DataLoader

# Load the ColPali retrieval model (cached across Streamlit reruns)
@st.cache_resource
def load_colpali_model():
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-mix-448-base", torch_dtype=torch.float32, device_map="cpu"
    ).eval()
    model.load_adapter("vidore/colpali")
    processor = AutoProcessor.from_pretrained("vidore/colpali")
    return model, processor

# Load the Qwen2-VL model (cached across Streamlit reruns)
@st.cache_resource
def load_qwen_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float32, device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    return model, processor
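
# A hedged alternative for GPU machines (not used in this CPU-only setup):
# half precision with automatic device placement roughly halves memory use.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float16, device_map="auto"
# )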

# Release Python memory and, if a GPU is present, the CUDA cache between model swaps
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Streamlit interface
st.title("OCR and Visual Language Model Demo")
st.write("Upload an image for OCR extraction and then ask a question about the image.")

# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # Embed the image with ColPali (a late-interaction retrieval model)
    st.write("Extracting text from image...")
    colpali_model, colpali_processor = load_colpali_model()

    # Batch the single image through ColPali's image preprocessing
    dataloader = DataLoader(
        [img],
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: process_images(colpali_processor, x),
    )

    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to("cpu") for k, v in batch_doc.items()}
            embeddings_doc = colpali_model(**batch_doc)

    # For simplicity, embed a dummy query; process_queries pads queries with a blank mock image
    dummy_query = "Extract all text from the image"
    query_dataloader = DataLoader(
        [dummy_query],
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: process_queries(
            colpali_processor, x, Image.new("RGB", (448, 448), (255, 255, 255))
        ),
    )

    for batch_query in query_dataloader:
        with torch.no_grad():
            batch_query = {k: v.to("cpu") for k, v in batch_query.items()}
            embeddings_query = colpali_model(**batch_query)
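
    # Illustrative sketch (an addition, not part of the placeholder flow):
    # ColPali embeddings are built for late-interaction retrieval scoring,
    # not OCR. Assuming the forward pass returns [batch, num_tokens, dim]
    # multi-vector tensors, the standard MaxSim score matches each query
    # token to its best image patch and sums the maxima.
    with torch.no_grad():
        sim = torch.einsum("bnd,bmd->bnm", embeddings_query, embeddings_doc)
        maxsim_score = sim.max(dim=2).values.sum(dim=1)  # shape: [batch]
    st.write(f"Query/page MaxSim relevance score: {maxsim_score.item():.2f}")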

    # ColPali returns retrieval embeddings rather than literal OCR output,
    # so this demo displays a placeholder instead of real extracted text
    extracted_text = (
        "This is a placeholder for the extracted text. In a real scenario, you "
        "would use the embeddings to extract actual text from the image."
    )
    st.write("Extracted Text:")
    st.write(extracted_text)

    # Free ColPali before loading Qwen2-VL
    del colpali_model, colpali_processor
    clear_memory()

    # Text input field for the question
    question = st.text_input("Ask a question about the image and extracted text")

    if question:
        st.write("Processing with Qwen2-VL...")
        qwen_model, qwen_processor = load_qwen_model()

        # Build a chat message carrying both the image and the text prompt
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {
                        "type": "text",
                        "text": f"Extracted text: {extracted_text}\n\nQuestion: {question}",
                    },
                ],
            }
        ]

        # Prepare model inputs: chat template, vision features, then tensors
        text_input = qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, _ = process_vision_info(messages)
        inputs = qwen_processor(
            text=[text_input], images=image_inputs, padding=True, return_tensors="pt"
        )

        # Keep tensors on CPU (the models were loaded with device_map="cpu")
        inputs = inputs.to("cpu")

        # Generate, then strip the prompt tokens so only the answer is decoded
        with torch.no_grad():
            generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
        trimmed_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        generated_text = qwen_processor.batch_decode(
            trimmed_ids, skip_special_tokens=True
        )[0]

        # Display the response
        st.write("Model's response:", generated_text)

        # Free Qwen2-VL
        del qwen_model, qwen_processor
        clear_memory()
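
# To run this Streamlit app locally:
#   streamlit run app.py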