lukiod committed
Commit 0ae684c
1 Parent(s): fff6204

Update app.py

Files changed (1)
  1. app.py +104 -42
app.py CHANGED
@@ -1,62 +1,124 @@
 import streamlit as st
+import torch
+from PIL import Image
+import gc
 from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
 from qwen_vl_utils import process_vision_info
-from PIL import Image
-import torch
+from colpali_engine.models.paligemma_colbert_architecture import ColPali
+from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
+from torch.utils.data import DataLoader
+
+# Function to load Colpali model
+@st.cache_resource
+def load_colpali_model():
+    model = ColPali.from_pretrained("vidore/colpaligemma-3b-mix-448-base", torch_dtype=torch.float32, device_map="cpu").eval()
+    model.load_adapter("vidore/colpali")
+    processor = AutoProcessor.from_pretrained("vidore/colpali")
+    return model, processor
 
-# Load the model and processor
+# Function to load Qwen2-VL model
 @st.cache_resource
-def load_model():
-    # Load Qwen2-VL-7B on CPU
+def load_qwen_model():
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.float32, device_map="cpu"
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
     return model, processor
 
-model, processor = load_model()
+# Function to clear GPU memory
+def clear_memory():
+    gc.collect()
+    torch.cuda.empty_cache()
 
 # Streamlit Interface
-st.title("Qwen2-VL-7B Multimodal Demo")
-st.write("Upload an image and provide a text prompt to see the model's response.")
+st.title("OCR and Visual Language Model Demo")
+st.write("Upload an image for OCR extraction and then ask a question about the image.")
 
 # Image uploader
 image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
-# Text input field
-text = st.text_input("Enter a text description or query")
-
-# If both image and text are provided
-if image and text:
-    # Load image with PIL
+if image:
     img = Image.open(image)
     st.image(img, caption="Uploaded Image", use_column_width=True)
 
-    # Prepare inputs for Qwen2-VL
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": img},
-                {"type": "text", "text": text},
-            ],
-        }
-    ]
-
-    # Prepare for inference
-    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    image_inputs, _ = process_vision_info(messages)
-    inputs = processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")
-
-    # Move tensors to CPU
-    inputs = inputs.to("cpu")
-
-    # Run the model and generate output
-    with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=128)
-
-    # Decode the output text
-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
-
-    # Display the response
-    st.write("Model's response:", generated_text)
+    # OCR Extraction with Colpali
+    st.write("Extracting text from image...")
+    colpali_model, colpali_processor = load_colpali_model()
+
+    # Process image for Colpali
+    dataloader = DataLoader(
+        [img],
+        batch_size=1,
+        shuffle=False,
+        collate_fn=lambda x: process_images(colpali_processor, x),
+    )
+
+    for batch_doc in dataloader:
+        with torch.no_grad():
+            batch_doc = {k: v.to('cpu') for k, v in batch_doc.items()}
+            embeddings_doc = colpali_model(**batch_doc)
+
+    # For simplicity, we'll use a dummy query to extract text
+    dummy_query = "Extract all text from the image"
+    query_dataloader = DataLoader(
+        [dummy_query],
+        batch_size=1,
+        shuffle=False,
+        collate_fn=lambda x: process_queries(colpali_processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
+    )
+
+    for batch_query in query_dataloader:
+        with torch.no_grad():
+            batch_query = {k: v.to('cpu') for k, v in batch_query.items()}
+            embeddings_query = colpali_model(**batch_query)
+
+    # In a real scenario, you'd use these embeddings to extract text
+    # For this demo, we'll just show a placeholder text
+    extracted_text = "This is a placeholder for the extracted text. In a real scenario, you would use the embeddings to extract actual text from the image."
+
+    st.write("Extracted Text:")
+    st.write(extracted_text)
+
+    # Clear Colpali model from memory
+    del colpali_model, colpali_processor
+    clear_memory()
+
+    # Text input field for question
+    question = st.text_input("Ask a question about the image and extracted text")
+
+    if question:
+        st.write("Processing with Qwen2-VL...")
+        qwen_model, qwen_processor = load_qwen_model()
+
+        # Prepare inputs for Qwen2-VL
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": img},
+                    {"type": "text", "text": f"Extracted text: {extracted_text}\n\nQuestion: {question}"},
+                ],
+            }
+        ]
+
+        # Prepare for inference
+        text_input = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        image_inputs, _ = process_vision_info(messages)
+        inputs = qwen_processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")
+
+        # Move tensors to CPU
+        inputs = inputs.to("cpu")
+
+        # Run the model and generate output
+        with torch.no_grad():
+            generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
+
+        # Decode the output text
+        generated_text = qwen_processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+        # Display the response
+        st.write("Model's response:", generated_text)
+
+        # Clear Qwen model from memory
+        del qwen_model, qwen_processor
+        clear_memory()
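
A note on the placeholder in the new app.py: ColPali produces multi-vector embeddings meant for retrieval-style relevance scoring rather than literal text extraction. A minimal sketch of how the diff's embeddings_query and embeddings_doc could be compared with late-interaction (MaxSim) scoring, assuming both are tensors shaped (batch, tokens, dim) as the ColPali forward pass returns:

import torch

def maxsim_score(embeddings_query: torch.Tensor, embeddings_doc: torch.Tensor) -> torch.Tensor:
    # Token-level similarity between every query token and every document patch;
    # the result has shape (n_queries, n_docs, query_tokens, doc_tokens).
    sim = torch.einsum("qnd,pmd->qpnm", embeddings_query, embeddings_doc)
    # MaxSim: keep the best-matching patch per query token, then sum over query tokens.
    return sim.max(dim=-1).values.sum(dim=-1)

# Illustrative shapes only (ColPali emits 128-dimensional vectors per token/patch).
scores = maxsim_score(torch.randn(1, 20, 128), torch.randn(1, 1030, 128))
print(scores.shape)  # (n_queries, n_docs) relevance matrix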
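One behaviour worth knowing about the generation step (in both the old and new code): batch_decode(generated_ids, ...) returns the chat prompt together with the answer, because generate() keeps the input ids at the start of each sequence. A minimal sketch of the usual trimming step, reusing the diff's inputs, generated_ids, and qwen_processor names:

# Drop the prompt tokens from each sequence so only the newly generated answer is decoded.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
answer = qwen_processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
st.write("Model's response:", answer)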