lukiod committed on
Commit
be730b6
1 Parent(s): 9353556

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -42
app.py CHANGED
@@ -2,19 +2,15 @@ import streamlit as st
2
  import torch
3
  from PIL import Image
4
  import gc
5
- from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
6
  from qwen_vl_utils import process_vision_info
7
- from colpali_engine.models.paligemma_colbert_architecture import ColPali
8
- from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
9
- from torch.utils.data import DataLoader
10
 
11
- # Function to load Colpali model
12
  @st.cache_resource
13
- def load_colpali_model():
14
- model = ColPali.from_pretrained("vidore/colpaligemma-3b-mix-448-base", torch_dtype=torch.float32, device_map="cpu").eval()
15
- model.load_adapter("vidore/colpali")
16
- processor = AutoProcessor.from_pretrained("vidore/colpali")
17
- return model, processor
18
 
19
  # Function to load Qwen2-VL model
20
  @st.cache_resource
@@ -41,46 +37,28 @@ if image:
41
  img = Image.open(image)
42
  st.image(img, caption="Uploaded Image", use_column_width=True)
43
 
44
- # OCR Extraction with Colpali
45
  st.write("Extracting text from image...")
46
- colpali_model, colpali_processor = load_colpali_model()
47
-
48
- # Process image for Colpali
49
- dataloader = DataLoader(
50
- [img],
51
- batch_size=1,
52
- shuffle=False,
53
- collate_fn=lambda x: process_images(colpali_processor, x),
54
- )
55
 
56
- for batch_doc in dataloader:
57
- with torch.no_grad():
58
- batch_doc = {k: v.to('cpu') for k, v in batch_doc.items()}
59
- embeddings_doc = colpali_model(**batch_doc)
60
 
61
- # For simplicity, we'll use a dummy query to extract text
62
- dummy_query = "Extract all text from the image"
63
- query_dataloader = DataLoader(
64
- [dummy_query],
65
- batch_size=1,
66
- shuffle=False,
67
- collate_fn=lambda x: process_queries(colpali_processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
68
- )
69
-
70
- for batch_query in query_dataloader:
71
- with torch.no_grad():
72
- batch_query = {k: v.to('cpu') for k, v in batch_query.items()}
73
- embeddings_query = colpali_model(**batch_query)
74
 
75
- # In a real scenario, you'd use these embeddings to extract text
76
- # For this demo, we'll just show a placeholder text
77
- extracted_text = "This is a placeholder for the extracted text. In a real scenario, you would use the embeddings to extract actual text from the image."
 
 
78
 
79
  st.write("Extracted Text:")
80
  st.write(extracted_text)
81
 
82
- # Clear Colpali model from memory
83
- del colpali_model, colpali_processor
84
  clear_memory()
85
 
86
  # Text input field for question
 
2
  import torch
3
  from PIL import Image
4
  import gc
5
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
6
  from qwen_vl_utils import process_vision_info
7
+ from byaldi import RAGMultiModalModel
 
 
8
 
9
+ # Function to load Byaldi model
10
  @st.cache_resource
11
+ def load_byaldi_model():
12
+ model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device="cpu")
13
+ return model
 
 
14
 
15
  # Function to load Qwen2-VL model
16
  @st.cache_resource
 
37
  img = Image.open(image)
38
  st.image(img, caption="Uploaded Image", use_column_width=True)
39
 
40
+ # OCR Extraction with Byaldi
41
  st.write("Extracting text from image...")
42
+ byaldi_model = load_byaldi_model()
 
 
 
 
 
 
 
 
43
 
44
+ # Create a temporary index for the uploaded image
45
+ with st.spinner("Processing image..."):
46
+ byaldi_model.index(img, index_name="temp_index", overwrite=True)
 
47
 
48
+ # Perform a dummy search to get the OCR results
49
+ ocr_results = byaldi_model.search("Extract all text from the image", k=1)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # Extract the OCR text from the results
52
+ if ocr_results:
53
+ extracted_text = ocr_results[0].metadata.get("ocr_text", "No text extracted")
54
+ else:
55
+ extracted_text = "No text extracted"
56
 
57
  st.write("Extracted Text:")
58
  st.write(extracted_text)
59
 
60
+ # Clear Byaldi model from memory
61
+ del byaldi_model
62
  clear_memory()
63
 
64
  # Text input field for question