|
import streamlit as st |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
from transformers import AutoTokenizer, AutoModel |
|
import textract |
|
import docx2txt |
|
import pdfplumber |
|
|
|
def last_token_pool(last_hidden_states: Tensor,

                attention_mask: Tensor) -> Tensor:
    """Pool each sequence down to the hidden state of its final attended token.

    When the batch is left-padded, every sequence's last real token sits at
    the final position, so that column is returned directly. Otherwise the
    index of the last non-padding token per row is derived from the mask.
    """
    # Left padding <=> every row of the mask attends at the final position.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]

    # Index of the last real (non-padding) token in each sequence.
    last_positions = attention_mask.sum(dim=1) - 1
    row_indices = torch.arange(last_hidden_states.shape[0],
                               device=last_hidden_states.device)
    return last_hidden_states[row_indices, last_positions]
|
|
|
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Join a task instruction and a query into the retrieval prompt format."""
    return 'Instruct: ' + task_description + '\nQuery: ' + query
|
|
|
# --- Streamlit UI widgets ---------------------------------------------------
st.title("Text Similarity Model")

# Instruction prefix for instruction-tuned retrieval models
# (consumed via get_detailed_instruct).
task = 'Given a web search query, retrieve relevant passages that answer the query'

# Sidebar uploader: accepts multiple documents of the supported types.
docs = st.sidebar.file_uploader("Upload documents", accept_multiple_files=True, type=['txt','pdf','xlsx','docx'])
# Free-text search query; search runs when the button is clicked.
query = st.text_input("Enter search query")
click = st.button("Search")
|
|
|
def extract_text(doc):
    """Extract plain text from an uploaded file (txt, pdf, docx or xlsx).

    Parameters
    ----------
    doc : Streamlit UploadedFile — assumed to expose ``.type`` (MIME string),
        ``.name`` and file-like read methods; verify against the uploader.

    Returns
    -------
    str or None — the extracted text, or ``None`` for unsupported types.
    """
    if doc.type == 'text/plain':
        return doc.getvalue().decode("utf-8")

    if doc.type == "application/pdf":
        with pdfplumber.open(doc) as pdf:
            pages = [page.extract_text() for page in pdf.pages]
        # extract_text() returns None for image-only pages; drop them so the
        # join below cannot raise TypeError.
        return "\n".join(p for p in pages if p)

    if doc.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return docx2txt.process(doc)

    if doc.name.endswith(".xlsx"):
        # NOTE(review): textract.process expects a filesystem path, not a
        # file-like object — confirm this branch works with UploadedFile,
        # or spill the upload to a temporary file first.
        text = textract.process(doc)
        return text.decode("utf-8")

    return None
|
|
|
|
|
if click and query:
    # Collect text from each uploaded document. extract_text returns None
    # for unsupported types, so filter those out rather than passing None
    # entries into the embedding step.
    doc_contents = []

    for doc in docs:

        doc_text = extract_text(doc)
        if doc_text:
            doc_contents.append(doc_text)

    if not doc_contents:
        # Nothing usable was extracted — report instead of embedding an
        # empty batch.
        st.warning("No text could be extracted from the uploaded documents.")
    else:
        # NOTE(review): get_embeddings, get_embedding, compute_similarity and
        # get_ranked_docs are not defined anywhere in this file — confirm
        # they are provided elsewhere, otherwise this branch raises NameError.
        doc_embeddings = get_embeddings(doc_contents)
        query_embedding = get_embedding(query)

        scores = compute_similarity(query_embedding, doc_embeddings)
        ranked_docs = get_ranked_docs(scores)

        st.write("Most Relevant Documents")
        for doc, score in ranked_docs:
            st.write(f"{doc.name} (score: {score:.2f})")
|
|
|
|
|
|