Prathmesh48 committed on
Commit a5bb707
1 Parent(s): 348ea05

Update app.py

Files changed (1)
  1. app.py +304 -36
app.py CHANGED
@@ -1,37 +1,305 @@
  import streamlit as st
- import torch
- from transformers import AutoTokenizer, AutoModel
-
- # Load the tokenizer and model
- @st.cache_resource
- def load_model():
-     tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
-     model = AutoModel.from_pretrained('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
-     model.to('cpu')
-     return tokenizer, model
-
- tokenizer, model = load_model()
-
- def extract_embeddings(text, tokenizer, model):
-     # Tokenize the input text
-     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-     inputs = {k: v.to('cpu') for k, v in inputs.items()}
-
-     # Get the model's outputs
-     with torch.no_grad():
-         outputs = model(**inputs)
-
-     # Extract the embeddings (use the output of the last hidden state)
-     embeddings = outputs.last_hidden_state.mean(dim=1)
-
-     return embeddings.squeeze().cpu().numpy()
-
- # Streamlit app
- st.title("Text Embeddings Extractor")
-
- text = st.text_area("Enter text to extract embeddings:", "This is an example sentence.")
-
- if st.button("Extract Embeddings"):
-     embeddings = extract_embeddings(text, tokenizer, model)
-     st.write("Embeddings:")
-     st.write(embeddings)
+ import concurrent.futures
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from functools import partial
+ import numpy as np
+ from io import StringIO
+ import sys
+ import time
+ import pandas as pd
+ from pymongo import MongoClient
+ import plotly.express as px
+ from pinecone import Pinecone, ServerlessSpec
+ import chromadb
+ import requests
+ from io import BytesIO
+ from PyPDF2 import PdfReader
+ import hashlib
+ import os
+
+ # File Imports
+ from embedding import get_embeddings, get_image_embeddings, get_embed_chroma, imporve_text  # Ensure this file/module is available
+ from preprocess import filtering  # Ensure this file/module is available
+ from search import *
+
+
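+ # Note: several imports above (time, pandas, pymongo, pinecone, os,
+ # functools.partial) are not referenced in this file; `imporve_text` is
+ # spelled as defined in the local `embedding` module.
+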
+ # Chroma Connections
+ client = chromadb.PersistentClient(path="embeddings")
+ collection = client.get_or_create_collection(name="data", metadata={"hnsw:space": "l2"})
+
+
+ def generate_hash(content):
+     return hashlib.sha256(content.encode('utf-8')).hexdigest()
+
+
+ def get_key(link):
+     text = ''
+     try:
+         # Fetch the PDF file from the URL
+         response = requests.get(link)
+         response.raise_for_status()  # Raise an error for bad status codes
+
+         # Use BytesIO to handle the PDF content in memory
+         pdf_file = BytesIO(response.content)
+
+         # Load the PDF file
+         reader = PdfReader(pdf_file)
+         num_pages = len(reader.pages)
+
+         first_page_text = reader.pages[0].extract_text()
+         if first_page_text:
+             text += first_page_text
+
+         last_page_text = reader.pages[-1].extract_text()
+         if last_page_text:
+             text += last_page_text
+
+     except requests.exceptions.HTTPError as e:
+         print(f'HTTP error occurred: {e}')
+     except Exception as e:
+         print(f'An error occurred: {e}')
+
+     unique_key = generate_hash(text)
+
+     return unique_key
+
+
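+ # Note: get_key() fingerprints a PDF by hashing the text of its first and
+ # last pages, so the same manual is not re-embedded. If the download or parse
+ # fails, `text` stays empty and every failing URL collapses to the same key.
+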
+ # Cosine Similarity Function
+ def cosine_similarity(vec1, vec2):
+     vec1 = np.array(vec1)
+     vec2 = np.array(vec2)
+
+     dot_product = np.dot(vec1, vec2.T)
+     magnitude_vec1 = np.linalg.norm(vec1)
+     magnitude_vec2 = np.linalg.norm(vec2)
+
+     if magnitude_vec1 == 0 or magnitude_vec2 == 0:
+         return 0.0
+
+     cosine_sim = dot_product / (magnitude_vec1 * magnitude_vec2)
+     return cosine_sim
+
+
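+ # Note: with 2-D inputs such as cosine_similarity([v1], [v2]),
+ # np.dot(vec1, vec2.T) yields a (1, 1) array, which is why callers below
+ # index the result with [0][0].
+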
+ def update_chroma(product_name, url, key, text, vector, log_area):
+     id_list = [key + str(i) for i in range(len(text))]
+
+     metadata_list = [
+         {'key': key,
+          'product_name': product_name,
+          'url': url,
+          'text': item
+          }
+         for item in text
+     ]
+
+     collection.upsert(
+         ids=id_list,
+         embeddings=vector,
+         metadatas=metadata_list
+     )
+
+     logger.write(f"\n\u2713 Updated DB - {url}\n\n")
+     log_area.text(logger.getvalue())
+
+
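+ # Note: update_chroma() reads the module-level `logger` that is bound by the
+ # `with StreamCapture() as logger:` block near the bottom of this script.
+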
+ # Logger class to capture output
+ class StreamCapture:
+     def __init__(self):
+         self.output = StringIO()
+         self._stdout = sys.stdout
+
+     def __enter__(self):
+         sys.stdout = self.output
+         return self.output
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         sys.stdout = self._stdout
+
+
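+ # Usage: `with StreamCapture() as logger:` redirects print() output into a
+ # StringIO, so it can be mirrored into the Streamlit console tab via
+ # log_area.text(logger.getvalue()).
+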
+ # Main Function
+ def score(main_product, main_url, product_count, link_count, search, logger, log_area):
+     data = {}
+     similar_products = extract_similar_products(main_product)[:product_count]
+
+     print("--> Fetching Manual Links")
+     # Normal Filtering + Embedding -----------------------------------------------
+     if search == 'All':
+
+         def process_product(product, search_function, main_product):
+             search_result = search_function(product)
+             return filtering(search_result, main_product, product, link_count)
+
+         search_functions = {
+             'google': search_google,
+             'duckduckgo': search_duckduckgo,
+             'github': search_github,
+             'wikipedia': search_wikipedia
+         }
+
+         with ThreadPoolExecutor() as executor:
+             future_to_product_search = {
+                 executor.submit(process_product, product, search_function, main_product): (product, search_name)
+                 for product in similar_products
+                 for search_name, search_function in search_functions.items()
+             }
+
+             for future in as_completed(future_to_product_search):
+                 product, search_name = future_to_product_search[future]
+                 try:
+                     if product not in data:
+                         data[product] = {}
+                     data[product] = future.result()
+                 except Exception as e:
+                     print(f"Error processing product {product} with {search_name}: {e}")
+
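+         # Note: results are keyed only by product, so the engine whose future
+         # completes last overwrites the earlier results for that product.
+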
+     else:
+
+         for product in similar_products:
+
+             if search == 'google':
+                 data[product] = filtering(search_google(product), main_product, product, link_count)
+             elif search == 'duckduckgo':
+                 data[product] = filtering(search_duckduckgo(product), main_product, product, link_count)
+             elif search == 'archive':
+                 data[product] = filtering(search_archive(product), main_product, product, link_count)
+             elif search == 'github':
+                 data[product] = filtering(search_github(product), main_product, product, link_count)
+             elif search == 'wikipedia':
+                 data[product] = filtering(search_wikipedia(product), main_product, product, link_count)
+
+     # Filtered Link -----------------------------------------
+     logger.write("\n\n\u2713 Filtered Links\n")
+     log_area.text(logger.getvalue())
+
+     # Main product Embeddings ---------------------------------
+     logger.write("\n\n--> Creating Main product Embeddings\n")
+
+     main_key = get_key(main_url)
+     main_text, main_vector = get_embed_chroma(main_url)
+
+     update_chroma(main_product, main_url, main_key, main_text, main_vector, log_area)
+
+     # log_area.text(logger.getvalue())
+     print("\n\n\u2713 Main Product embeddings Created")
+
+     logger.write("\n\n--> Creating Similar product Embeddings\n")
+     log_area.text(logger.getvalue())
+     test_embedding = [0] * 768
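+     # `test_embedding` is a dummy query vector: only the `where` filter
+     # matters below, when probing whether a key was already embedded.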
+
+     for product in data:
+         for link in data[product]:
+
+             url, _ = link
+             similar_key = get_key(url)
+
+             res = collection.query(
+                 query_embeddings=[test_embedding],
+                 n_results=1,
+                 where={"key": similar_key},
+             )
+
+             if not res['distances'][0]:
+                 similar_text, similar_vector = get_embed_chroma(url)
+                 update_chroma(product, url, similar_key, similar_text, similar_vector, log_area)
+
+     logger.write("\n\n\u2713 Similar Product embeddings Created\n")
+     log_area.text(logger.getvalue())
+
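+     # For each chunk of the main manual, fetch its single nearest neighbour
+     # from any other document (L2 distance, per the "hnsw:space" setting
+     # above, so smaller is more similar).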
+     top_similar = []
+
+     for idx, chunk in enumerate(main_vector):
+         res = collection.query(
+             query_embeddings=[chunk],
+             n_results=1,
+             where={"key": {'$ne': main_key}},
+             include=['metadatas', 'embeddings', 'distances']
+         )
+
+         top_similar.append((main_text[idx], chunk, res, res['distances'][0]))
+
+     most_similar_items = sorted(top_similar, key=lambda x: x[3])[:top_similar_count]
+
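+     # Note: `top_similar_count` is not a parameter of score(); it is read
+     # from the module-level Streamlit widget defined in the sidebar below.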
+     logger.write("--------------- DONE -----------------\n")
+     log_area.text(logger.getvalue())
+
+     return most_similar_items
+
+
+ # Streamlit Interface
+ st.title("Check Infringement")
+
+ # Inputs
+ with st.sidebar:
+     st.header("Product Information")
+     main_product = st.text_input('Enter Main Product Name', 'Philips led 7w bulb')
+     main_url = st.text_input('Enter Main Product Manual URL', 'https://www.assets.signify.com/is/content/PhilipsConsumer/PDFDownloads/Colombia/technical-sheets/ODLI20180227_001-UPD-es_CO-Ficha_Tecnica_LED_MR16_Master_7W_Dim_12V_CRI90.pdf')
+
+     st.header("Search Settings")
+     search_method = st.selectbox('Choose Search Engine', ['All', 'duckduckgo', 'google', 'archive', 'github', 'wikipedia'])
+
+     product_count = st.number_input("Number of Similar Products", min_value=1, step=1, format="%i")
+     link_count = st.number_input("Number of Links per Product", min_value=1, step=1, format="%i")
+     need_image = st.selectbox("Process Images", ['True', 'False'])
+
+     top_similar_count = st.number_input("Top Similarities to be Displayed", value=3, min_value=1, step=1, format="%i")
+
+ if st.button('Check for Infringement'):
+     global log_output  # Placeholder for log output (a no-op at module scope)
+
+     tab1, tab2 = st.tabs(["Output", "Console"])
+
+     with tab2:
+         log_output = st.empty()
+
+     with tab1:
+         with st.spinner('Processing...'):
+             with StreamCapture() as logger:
+                 top_similar_values = score(main_product, main_url, product_count, link_count, search_method, logger, log_output)
+
+         st.success('Processing complete!')
+
+         st.subheader("Cosine Similarity Scores")
+
+         for main_text, main_vector, response, _ in top_similar_values:
+             product_name = response['metadatas'][0][0]['product_name']
+             link = response['metadatas'][0][0]['url']
+             similar_text = response['metadatas'][0][0]['text']
+
+             cosine_score = cosine_similarity([main_vector], response['embeddings'][0])[0][0]
+
+             # Display the product information
+             with st.container():
+                 st.markdown(f"### [Product: {product_name}]({link})")
+                 st.markdown(f"#### Cosine Score: {cosine_score:.4f}")
+                 col1, col2 = st.columns(2)
+                 with col1:
+                     st.markdown(f"**Main Text:** {imporve_text(main_text)}")
+                 with col2:
+                     st.markdown(f"**Similar Text:** {imporve_text(similar_text)}")
+
+             st.markdown("---")
+
+         if need_image == 'True':
+             with st.spinner('Processing Images...'):
+                 emb_main = get_image_embeddings(main_product)
+                 similar_prod = extract_similar_products(main_product)[0]
+                 emb_similar = get_image_embeddings(similar_prod)
+
+                 similarity_matrix = np.zeros((5, 5))
+                 for i in range(5):
+                     for j in range(5):
+                         similarity_matrix[i][j] = cosine_similarity([emb_main[i]], [emb_similar[j]])[0][0]
+
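+                 # Note: this assumes get_image_embeddings() returns at least
+                 # five embeddings per product. A vectorized sketch under the
+                 # same assumption:
+                 #   a = np.array(emb_main[:5]);    a /= np.linalg.norm(a, axis=1, keepdims=True)
+                 #   b = np.array(emb_similar[:5]); b /= np.linalg.norm(b, axis=1, keepdims=True)
+                 #   similarity_matrix = a @ b.T
+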
+                 st.subheader("Image Similarity")
+                 # Create an interactive heatmap
+                 fig = px.imshow(similarity_matrix,
+                                 labels=dict(x=f"{similar_prod} Images", y=f"{main_product} Images", color="Similarity"),
+                                 x=[f"Image {i+1}" for i in range(5)],
+                                 y=[f"Image {i+1}" for i in range(5)],
+                                 color_continuous_scale="Viridis")
+
+                 # Add title to the heatmap
+                 fig.update_layout(title="Image Similarity Heatmap")
+
+                 # Display the interactive heatmap
+                 st.plotly_chart(fig)