Spaces:
Running
Running
import nest_asyncio | |
import streamlit as st | |
import os | |
from flashrank import Ranker, RerankRequest | |
from qdrant_client import QdrantClient | |
from llama_index.llms.groq import Groq | |
from llama_index.core.retrievers import VectorIndexRetriever | |
from llama_index.retrievers.bm25 import BM25Retriever | |
from llama_index.core.retrievers import QueryFusionRetriever | |
from llama_index.core import VectorStoreIndex | |
from llama_index.vector_stores.qdrant import QdrantVectorStore | |
from llama_index.core import Settings | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
import PyPDF2 | |
# Allow nested event loops (nest_asyncio patches asyncio) so library code that
# runs its own loop can be called from Streamlit's script thread.
nest_asyncio.apply()

# Credentials come from Streamlit secrets. HF_TOKEN is exported to the process
# environment, where huggingface_hub looks for it when downloading models.
os.environ.update({"HF_TOKEN": st.secrets["HF_TOKEN"]})
groq_token = st.secrets["groq_token"]

st.set_page_config(layout="wide")
# Default LlamaIndex LLM and embedding model selection.
@st.cache_resource(show_spinner="Loading embedding model...")
def _load_embed_model():
    """Load the InLegalBERT embedding model once per server process.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the Hugging Face model would be re-instantiated on each rerun.
    """
    return HuggingFaceEmbedding(
        model_name="law-ai/InLegalBERT", trust_remote_code=True
    )


def llamaindex_default():
    """Point LlamaIndex's global ``Settings`` at the Groq LLM and the embedding model.

    The embedding model comes from the process-wide cache. The Groq client is
    a lightweight API wrapper and is simply rebuilt on each run.
    """
    Settings.llm = Groq(model="llama-3.1-8b-instant", api_key=groq_token)
    Settings.embed_model = _load_embed_model()


llamaindex_default()
# Set up the Qdrant client and wrap the existing "legal_v1" collection as an index.
@st.cache_resource
def load_index():
    """Return a ``VectorStoreIndex`` over the local Qdrant collection "legal_v1".

    Cached for the lifetime of the server process. Qdrant's local (``path=``)
    mode locks its storage folder, so opening a fresh client on every
    Streamlit rerun or session is wasteful and can fail because the folder is
    already held by an earlier client.
    """
    qdrant_client = QdrantClient(path=".")
    vector_store = QdrantVectorStore(
        client=qdrant_client, collection_name="legal_v1"
    )
    return VectorStoreIndex.from_vector_store(vector_store=vector_store)


index = load_index()
# Reranker selection and retrieval settings in the sidebar.
@st.cache_resource(show_spinner="Loading reranker...")
def _load_ranker(model_name):
    """Return a FlashRank ``Ranker`` for *model_name*, loaded once per process.

    ``"default"`` uses FlashRank's built-in default model and cache location.
    Any other name is passed to FlashRank with ``cache_dir="."`` (the current
    directory).
    """
    if model_name == "default":
        return Ranker()
    return Ranker(model_name=model_name, cache_dir=".")


with st.sidebar:
    selected_reranker = st.selectbox(
        "Select a reranker",
        ("default", "ms-marco-MiniLM-L-12-v2", "rank-T5-flan"),
    )
    # Cached per model name, so switching back and forth does not reload weights.
    ranker = _load_ranker(selected_reranker)

    # The dense and sparse fusion weights always sum to 1.
    dense_weightage = st.slider(
        "Dense Weightage", min_value=0.0, max_value=1.0, value=0.5, step=0.1
    )
    # Round to the slider's precision so e.g. 1 - 0.7 shows as 0.3,
    # not 0.30000000000000004.
    sparse_weightage = round(1 - dense_weightage, 1)
    st.write("dense weight: ", dense_weightage)
    st.write("sparse weight: ", sparse_weightage)

    # Number of results each retriever (and the fused retriever) returns.
    num_k = st.number_input(
        "Enter k",
        min_value=1,
        max_value=10,
        value=10,
    )
def load_retriver():
    """Build the hybrid (dense + BM25) retriever from the current sidebar settings.

    Both sub-retrievers and the fused result return ``num_k`` nodes. Scores
    are combined with relative-score fusion, weighted by ``dense_weightage``
    and ``sparse_weightage``. ``num_queries=1`` means only the user's query
    is used, with no LLM-generated query variants.
    """
    dense = VectorIndexRetriever(index=index, similarity_top_k=num_k)

    # The BM25 index was persisted ahead of time; only its top-k is adjusted here.
    sparse = BM25Retriever.from_persist_dir("./sparse_retriever")
    sparse.similarity_top_k = num_k

    fusion_options = dict(
        num_queries=1,
        use_async=False,
        retriever_weights=[dense_weightage, sparse_weightage],
        similarity_top_k=num_k,
        mode="relative_score",
        verbose=True,
    )
    return QueryFusionRetriever([dense, sparse], **fusion_options)


retriever = load_retriver()
def extract_pdf_content(pdf_file_path):
    """Return the extracted text of every page in the PDF at *pdf_file_path*.

    Pages are joined with a newline, so the last word of one page is not
    glued to the first word of the next. A page whose extraction yields
    nothing contributes an empty string. A PDF with no pages returns "".
    """
    with open(pdf_file_path, "rb") as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Extract while the file is still open: the reader may pull page
        # objects from the underlying stream on demand.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
# Prompt template for legal-document summarization.
# The only placeholder is {document_content}; it is filled with str.format(),
# so any literal braces added to this text later must be doubled ({{ }}).
# The rest of the prompt asks the LLM to reply directly in the fixed Markdown
# layout below, with no introductory text.
template = """
Please summarize the following legal document and provide the summary in the specified format. The output should directly follow the format without any introductory text.
**Document:**
{document_content}
**Format:**
**Case:** [Case Number]
**Petitioner:** [Petitioner's Name]
**Respondent:** [Respondent's Name]
**Judge:** [Judge's Name]
**Order Date:** [Order Date]
**Summary:**
- **Background:** [Brief description of the case background]
- **Allegations:** [Summary of the allegations made in the case]
- **Investigation:** [Key findings from the investigation]
- **Court's Decision:** [Summary of the court's decision and any conditions imposed]
"""
st.title("Legal Documents Hybrid+Reranker Search")
query = st.text_input("Search through documents by keyword", value="")
search_btn = st.button("Search")


@st.cache_data(show_spinner="Summarizing document...")
def _summarize_document(file_name):
    """Return the LLM summary (Markdown text) of ``./documents/<file_name>``.

    Cached by file name. Several retrieved chunks often come from the same
    PDF, and Streamlit runs expander bodies even when they are collapsed.
    Without the cache, every search would re-read each PDF and make one LLM
    call per result.
    """
    text = extract_pdf_content(os.path.join("./documents", file_name))
    prompt = template.format(document_content=text)
    return Settings.llm.complete(prompt).text


if search_btn and query:
    nodes = retriever.retrieve(query)
    # FlashRank expects plain dicts with "id", "text" and optional "meta".
    passages = [
        {"id": node.node_id, "text": node.text, "meta": node.metadata}
        for node in nodes
    ]
    if not passages:
        st.info("No matching documents found.")
    else:
        results = ranker.rerank(RerankRequest(query=query, passages=passages))
        for result in results:
            file_name = result["meta"].get("file_name")
            st.write("File Name: ", file_name)
            st.write("reranking score: ", result["score"])
            st.write("node id", result["id"])
            with st.expander("See Summary"):
                # Without a source file name there is no PDF to summarize.
                if file_name:
                    st.markdown(_summarize_document(file_name))
                else:
                    st.warning("No source file recorded for this result.")
            st.write("---")