# Provenance (residue from the repository web view, kept as a comment so the
# file parses): author harshil1973, commit 4dc9e48 "let's see", 4.83 kB.
import nest_asyncio
import streamlit as st
import os
from flashrank import Ranker, RerankRequest
from qdrant_client import QdrantClient
from llama_index.llms.groq import Groq
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import PyPDF2
# Allow nested asyncio event loops (nest_asyncio), since llama_index may call
# asyncio from inside an already-running loop.
nest_asyncio.apply()

# Credentials are read from Streamlit secrets. HF_TOKEN is exported through the
# environment (presumably read by the Hugging Face client when the embedding
# model is fetched); the Groq key is passed explicitly to the LLM below.
os.environ.update(HF_TOKEN=st.secrets["HF_TOKEN"])
groq_token = st.secrets["groq_token"]

st.set_page_config(layout="wide")
# default llamaindex llm and embendding model selection
@st.cache_resource(show_spinner=False)
def llamaindex_default():
Settings.llm = Groq(model="llama-3.1-8b-instant", api_key=groq_token)
Settings.embed_model = HuggingFaceEmbedding(
model_name="law-ai/InLegalBERT", trust_remote_code=True
)
llamaindex_default()
# set up qdrant client
@st.cache_resource(show_spinner=False)
def load_index():
qdrant_client = QdrantClient(
path="."
)
vector_store = QdrantVectorStore(
client=qdrant_client, collection_name="legal_v1"
)
return VectorStoreIndex.from_vector_store(vector_store=vector_store)
index = load_index()
# reranker selection in the sidebar
with st.sidebar:
selected_reranker = st.selectbox(
"Select a reranker",
("default", "ms-marco-MiniLM-L-12-v2", "rank-T5-flan")
)
if selected_reranker == "default":
ranker = Ranker()
else:
ranker = Ranker(model_name=selected_reranker, cache_dir=".")
# Calculate individual weightages with sidebar slider
dense_weightage = st.slider("Dense Weightage", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
sparse_weightage = 1 - dense_weightage
st.write("dense weight: ",dense_weightage)
st.write("sparse weight: ",sparse_weightage)
num_k = st.number_input(
"Enter k",
min_value=1,
max_value=10,
value=10
)
@st.cache_resource(show_spinner=False)
def load_retriver():
dense_retriever = VectorIndexRetriever(
index=index,
similarity_top_k=num_k
)
sparse_retriever = BM25Retriever.from_persist_dir("./sparse_retriever")
sparse_retriever.similarity_top_k = num_k
retriever = QueryFusionRetriever(
[
dense_retriever,
sparse_retriever,
],
num_queries=1,
use_async=False,
retriever_weights=[dense_weightage, sparse_weightage],
similarity_top_k=num_k,
mode="relative_score",
verbose=True,
)
return retriever
retriever = load_retriver()
def extract_pdf_content(pdf_file_path):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        pdf_file_path: Path to the PDF file on disk.

    Returns:
        The extracted text of all pages joined with no separator. A page whose
        extraction yields nothing contributes an empty string.
    """
    with open(pdf_file_path, "rb") as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Pages must be read while the file is still open. Treat a falsy
        # (None/empty) extraction result as "" so one bad page cannot break
        # the concatenation.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)
# Prompt template for summarizing a legal document. The single placeholder
# {document_content} is filled with str.format(); the template contains no
# other braces, so it is safe to format directly.
template = """
Please summarize the following legal document and provide the summary in the specified format. The output should directly follow the format without any introductory text.
**Document:**
{document_content}
**Format:**
**Case:** [Case Number]
**Petitioner:** [Petitioner's Name]
**Respondent:** [Respondent's Name]
**Judge:** [Judge's Name]
**Order Date:** [Order Date]
**Summary:**
- **Background:** [Brief description of the case background]
- **Allegations:** [Summary of the allegations made in the case]
- **Investigation:** [Key findings from the investigation]
- **Court's Decision:** [Summary of the court's decision and any conditions imposed]
"""
# Main page: query -> hybrid retrieval -> FlashRank rerank -> per-hit summary
# of the source PDF.
st.title("Legal Documents Hybrid+Reranker Search")
query = st.text_input("Search through documents by keyword", value="")
search_btn = st.button("Search")


@st.cache_data(show_spinner=False)
def _summarize_document(file_name):
    """Return the LLM summary of ``./documents/<file_name>`` as plain text.

    Cached per file name. Streamlit executes expander bodies even when they
    are collapsed, so without the cache every hit's PDF was re-read and
    re-summarized on each search, including repeat hits from the same file.
    """
    text = extract_pdf_content(os.path.join("./documents", file_name))
    response = Settings.llm.complete(template.format(document_content=text))
    return response.text


if search_btn and query:
    nodes = retriever.retrieve(query)
    if not nodes:
        # Nothing to rerank; skip the reranker instead of passing it an empty list.
        st.info("No matching documents found.")
    else:
        passages = [
            {"id": node.node_id, "text": node.text, "meta": node.metadata}
            for node in nodes
        ]
        results = ranker.rerank(RerankRequest(query=query, passages=passages))
        for result in results:
            file_name = result["meta"].get("file_name")
            st.write("File Name: ", file_name)
            st.write("reranking score: ", result["score"])
            st.write("node id", result["id"])
            with st.expander("See Summary"):
                # A node without file_name metadata previously crashed on
                # "./documents/" + None.
                if file_name:
                    st.markdown(_summarize_document(file_name))
                else:
                    st.write("No source file is recorded for this result.")
            st.write("---")