lukiod commited on
Commit
7901fac
1 Parent(s): 7a5741e

Add application file

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
4
+ from PIL import Image
5
+ from byaldi import RAGMultiModalModel
6
+ from qwen_vl_utils import process_vision_info
7
+
8
+ # Model and processor names
9
+ RAG_MODEL = "vidore/colpali"
10
+ QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
11
+ QWN_PROCESSOR = "Qwen/Qwen2-VL-2B-Instruct"
12
+
13
+ @st.cache_resource
14
+ def load_models():
15
+ RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)
16
+
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ QWN_MODEL,
19
+ torch_dtype=torch.bfloat16,
20
+ trust_remote_code=True
21
+ ).cuda().eval()
22
+
23
+ processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
24
+ tokenizer = AutoTokenizer.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
25
+
26
+ return RAG, model, processor, tokenizer
27
+
28
+ RAG, model, processor, tokenizer = load_models()
29
+
30
+ def document_rag(text_query, image):
31
+ messages = [
32
+ {
33
+ "role": "user",
34
+ "content": [
35
+ {
36
+ "type": "image",
37
+ "image": image,
38
+ },
39
+ {"type": "text", "text": text_query},
40
+ ],
41
+ }
42
+ ]
43
+ text = tokenizer.apply_chat_template(
44
+ messages, tokenize=False, add_generation_prompt=True
45
+ )
46
+ image_inputs, video_inputs = process_vision_info(messages)
47
+ inputs = processor(
48
+ text=[text],
49
+ images=image_inputs,
50
+ videos=video_inputs,
51
+ padding=True,
52
+ return_tensors="pt",
53
+ )
54
+ inputs = inputs.to("cuda")
55
+ generated_ids = model.generate(**inputs, max_new_tokens=50)
56
+ generated_ids_trimmed = [
57
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
58
+ ]
59
+ output_text = tokenizer.batch_decode(
60
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
61
+ )
62
+ return output_text[0]
63
+
64
+ st.title("Document Processor")
65
+
66
+ uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
67
+ text_query = st.text_input("Enter your text query")
68
+
69
+ if uploaded_file is not None and text_query:
70
+ image = Image.open(uploaded_file)
71
+
72
+ if st.button("Process Document"):
73
+ with st.spinner("Processing..."):
74
+ result = document_rag(text_query, image)
75
+ st.success("Processing complete!")
76
+ st.write("Result:", result)