File size: 2,428 Bytes
011c6b2
 
0c7ffdb
 
 
2524123
79ecc72
 
 
c208ca1
6ed8967
0c7ffdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f52a963
 
 
 
 
 
 
 
 
 
 
 
 
79ecc72
 
 
0c7ffdb
f52a963
0ee4a85
60eae40
0ee4a85
 
60eae40
9bbbf26
f52a963
faa2e50
958bbd7
f52a963
faa2e50
958bbd7
faa2e50
958bbd7
 
0ee4a85
 
 
 
 
60eae40
79ecc72
 
0c7ffdb
79ecc72
 
 
 
0c7ffdb
79ecc72
 
011c6b2
79ecc72
 
011c6b2
79ecc72
 
 
 
 
0c7ffdb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import streamlit as st
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import tempfile
import textract
import docx2txt
import pdfplumber
import io
import os

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Select, for each batch row, the hidden state at its final attended position.

    If the attention mask's last column sums to the batch size (with a 0/1
    mask: every row ends on a real token, e.g. left padding or no padding),
    the last position is taken for every row. Otherwise each row's final
    position is its mask sum minus one, which matches right padding.

    Args:
        last_hidden_states: hidden states indexed as ``[row, position, ...]``.
        attention_mask: per-token mask whose row sums count attended tokens
            (assumed 0/1 -- TODO confirm with the tokenizer output).

    Returns:
        One hidden-state slice per batch row.
    """
    batch_size = attention_mask.shape[0]
    every_row_ends_on_token = attention_mask[:, -1].sum() == batch_size
    if every_row_ends_on_token:
        return last_hidden_states[:, -1]

    final_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, final_positions]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Prefix a query with its task instruction.

    The output is ``Instruct: <task_description>``, a newline, then
    ``Query: <query>``.
    """
    template = 'Instruct: {}\nQuery: {}'
    return template.format(task_description, query)

st.title("Text Similarity Model")

# Retrieval instruction; presumably intended to be combined with the user's
# query via get_detailed_instruct() -- it is not used in the visible source.
task = 'Given a web search query, retrieve relevant passages that answer the query'


# Directory (relative to the working directory) that save_upload() writes into.
UPLOAD_DIR = "uploads"

# exist_ok=True replaces the check-then-create pair os.path.exists() +
# os.mkdir(): Streamlit re-executes this script on every rerun, potentially for
# concurrent sessions, so two runs could both pass the exists() check and the
# second mkdir() would raise FileExistsError.
os.makedirs(UPLOAD_DIR, exist_ok=True)

def save_upload(uploaded_file):
    """Write an uploaded file's bytes into UPLOAD_DIR and return the written path.

    Only the final path component of the client-supplied name is used. A name
    such as ``../../x`` cannot escape UPLOAD_DIR, and neither can an absolute
    path, for which ``os.path.join`` would otherwise discard UPLOAD_DIR
    entirely. An existing file with the same name is overwritten, because the
    file is opened in ``"wb"`` mode.

    Args:
        uploaded_file: object exposing ``name`` and ``getbuffer()`` (such as a
            Streamlit ``UploadedFile``).

    Returns:
        The path of the file that was written.

    Raises:
        ValueError: if the name has no usable final component.
    """
    # The filename comes from the client and is untrusted; strip any directories.
    safe_name = os.path.basename(uploaded_file.name)
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload filename: {uploaded_file.name!r}")

    filepath = os.path.join(UPLOAD_DIR, safe_name)
    with open(filepath, "wb") as f:
        f.write(uploaded_file.getbuffer())

    return filepath

# Sidebar uploader for the documents to search. Note: 'xlsx' is accepted here
# even though extract_text() returns None for it.
docs = st.sidebar.file_uploader(
    label="Upload documents",
    type=['txt', 'pdf', 'xlsx', 'docx'],
    accept_multiple_files=True,
)
# Search query text and the button that triggers a search on this rerun.
query = st.text_input(label="Enter search query")
click = st.button(label="Search")



def extract_text(doc):
    """Return the plain-text content of an uploaded document, or ``None``.

    Plain text is recognised by the reported MIME type ``text/plain`` or a
    ``.txt`` extension and decoded as UTF-8. PDF and DOCX are recognised by
    their extension, case-insensitively. Any other format yields ``None``;
    this includes ``.xlsx``, which the sidebar uploader accepts.

    Args:
        doc: uploaded file object exposing ``name``, ``type`` and
            ``getvalue()`` (as Streamlit's ``UploadedFile`` does).

    Returns:
        The extracted text, or ``None`` if the format is not supported.
    """
    # getvalue() returns the whole buffer regardless of the current read
    # position, so a file that has already been read still yields its content.
    data = doc.getvalue()
    name = doc.name.lower()

    if doc.type == 'text/plain' or name.endswith('.txt'):
        return data.decode('utf-8')

    if name.endswith('.pdf'):
        # pdfplumber.open accepts a file-like object, so parse in memory rather
        # than writing the client-supplied filename to disk.
        with pdfplumber.open(io.BytesIO(data)) as pdf:
            # extract_text() may return None for a page with no text layer
            # (depending on pdfplumber version); treat it as empty so join()
            # does not raise TypeError.
            return "\n".join(page.extract_text() or "" for page in pdf.pages)

    if name.endswith('.docx'):
        # docx2txt.process hands its argument to zipfile.ZipFile, which needs a
        # path or a seekable file object; passing raw bytes fails.
        return docx2txt.process(io.BytesIO(data))

    return None

# Runs only on the rerun in which the Search button was pressed and a non-empty
# query was entered.
if click and query:
    if not docs:
        # Nothing uploaded yet: the uploader may return an empty list (or None),
        # so there is nothing to rank.
        st.warning("Please upload at least one document to search.")
    else:
        doc_contents = []
        for doc in docs:
            doc_text = extract_text(doc)
            if doc_text is None:
                # extract_text() returns None for formats it cannot read (e.g.
                # .xlsx). Keep an empty placeholder so doc_contents stays
                # index-aligned with docs, and tell the user.
                st.warning(f"Could not extract text from {doc.name}; it will be scored as empty.")
                doc_text = ""
            doc_contents.append(doc_text)

        # NOTE(review): get_embeddings, get_embedding, compute_similarity and
        # get_ranked_docs are not defined in the visible source; unless they
        # are provided elsewhere, this branch raises NameError. The loop below
        # assumes get_ranked_docs yields (uploaded_file, score) pairs -- TODO
        # confirm.
        doc_embeddings = get_embeddings(doc_contents)
        query_embedding = get_embedding(query)

        scores = compute_similarity(query_embedding, doc_embeddings)
        ranked_docs = get_ranked_docs(scores)

        st.write("Most Relevant Documents")
        for doc, score in ranked_docs:
            st.write(f"{doc.name} (score: {score:.2f})")