File size: 2,213 Bytes
011c6b2
 
0c7ffdb
 
 
79ecc72
 
 
0c7ffdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79ecc72
 
 
0c7ffdb
60eae40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79ecc72
 
0c7ffdb
79ecc72
 
 
 
0c7ffdb
79ecc72
 
011c6b2
79ecc72
 
011c6b2
79ecc72
 
 
 
 
0c7ffdb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import textract
import docx2txt
import pdfplumber

def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Pool each sequence down to the embedding of its final non-padding token.

    Handles both padding conventions: if every row's mask ends in 1 (i.e. the
    batch is left-padded), the last position already holds the final token;
    otherwise each row's last valid position is gathered individually.
    """
    batch_rows = attention_mask.shape[0]
    # True only when all sequences are valid at the final position (left padding).
    ends_are_valid = attention_mask[:, -1].sum() == batch_rows
    if ends_are_valid:
        return last_hidden_states[:, -1]
    # Index of the last real token in each row: (# valid tokens) - 1.
    final_positions = attention_mask.sum(dim=1) - 1
    row_indices = torch.arange(batch_rows, device=last_hidden_states.device)
    return last_hidden_states[row_indices, final_positions]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Join a task instruction and a user query into the single prompt string
    expected by instruct-tuned embedding models."""
    return 'Instruct: ' + task_description + '\nQuery: ' + query

# --- Streamlit UI setup: page title, document uploader, query box, button ---
st.title("Text Similarity Model")

# Instruction string for instruct-tuned retrieval models.
# NOTE(review): `task` is never used in this file — presumably meant to be
# passed to get_detailed_instruct when embedding the query; confirm.
task = 'Given a web search query, retrieve relevant passages that answer the query'  

# Sidebar multi-file uploader; the four allowed extensions mirror the
# branches handled by extract_text.
docs = st.sidebar.file_uploader("Upload documents", accept_multiple_files=True, type=['txt','pdf','xlsx','docx']) 
query = st.text_input("Enter search query")
click = st.button("Search")

def extract_text(doc):
    """Return the plain-text content of an uploaded file, or None if unsupported.

    Parameters
    ----------
    doc : file-like upload object exposing ``.type`` (MIME string), ``.name``,
        and ``.getvalue()`` — assumed to be a Streamlit ``UploadedFile``;
        TODO confirm against the caller.

    Returns
    -------
    str | None
        Decoded text for txt/pdf/docx/xlsx uploads; ``None`` for anything else.
    """
    if doc.type == 'text/plain':
        return doc.getvalue().decode("utf-8")

    if doc.type == "application/pdf":
        with pdfplumber.open(doc) as pdf:
            # page.extract_text() returns None for pages with no text layer
            # (e.g. scanned images); drop those so the join below cannot
            # raise TypeError on a None entry.
            pages = [text for page in pdf.pages
                     if (text := page.extract_text()) is not None]
            return "\n".join(pages)

    if doc.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return docx2txt.process(doc)

    if doc.name.endswith(".xlsx"):
        text = textract.process(doc)
        return text.decode("utf-8")

    # Unsupported type — callers must tolerate a None result.
    return None
    

# Main search flow — runs only once the user has typed a query and pressed
# the Search button.
if click and query:
    doc_contents = []
    
    for doc in docs:
        # Extract text from each document
        # NOTE(review): extract_text returns None for unsupported types and
        # that None is appended as-is — confirm downstream embedding code
        # tolerates None entries.
        doc_text = extract_text(doc)  
        doc_contents.append(doc_text)

    # NOTE(review): get_embeddings, get_embedding, compute_similarity and
    # get_ranked_docs are not defined or imported anywhere in this file —
    # as written this raises NameError at runtime; confirm they exist
    # elsewhere or implement them.
    doc_embeddings = get_embeddings(doc_contents)  
    query_embedding = get_embedding(query)
    
    scores = compute_similarity(query_embedding, doc_embeddings)
    ranked_docs = get_ranked_docs(scores)
    
    st.write("Most Relevant Documents")
    # NOTE(review): assumes ranked_docs yields (doc, score) pairs where doc
    # exposes .name (i.e. the original upload objects, not the extracted
    # strings in doc_contents) — verify against get_ranked_docs.
    for doc, score in ranked_docs:
        st.write(f"{doc.name} (score: {score:.2f})")