# luiscgp's picture
# minor
# afc2679
# raw
# history blame
# 11 kB
from typing import List, Optional
import torch
import streamlit as st
import pandas as pd
import random
import time
import logging
import shutil
from json import JSONDecodeError
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.modeling.utils import initialize_device_settings
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import Pipeline
from haystack.nodes.base import BaseComponent
from haystack.schema import Document
from config import (
RETRIEVER_TOP_K,
RETRIEVER_MODEL,
NLI_MODEL,
)
class EntailmentChecker(BaseComponent):
    """
    Haystack node that checks the entailment between every document content and the statement.

    It enriches each document's metadata with entailment information
    (stored under ``doc.meta["entailment_info"]``) and also returns aggregate
    entailment information for the whole set of considered documents.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 100,
        entailment_contradiction_consideration: float = 0.6,
        entailment_contradiction_threshold: float = 0.8,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub.
            Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (defaults to ``model_name_or_path``).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Maximum number of premise/hypothesis pairs sent to the model
            in a single forward pass.
        :param entailment_contradiction_consideration: A document is kept (and counted in the
            aggregate) only if its entailment or contradiction score is greater than this value.
        :param entailment_contradiction_threshold: Scanning stops early once the running average
            of entailment or contradiction over the kept documents is greater than this value.
        :raises ValueError: If the model's ``id2label`` does not contain the labels
            ``entailment``, ``neutral`` and ``contradiction`` (all three are read by ``run``).
        """
        super().__init__()
        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.entailment_contradiction_consideration = entailment_contradiction_consideration
        self.model.to(str(self.devices[0]))
        # Take the labels from the config of the model actually loaded, so they match the
        # requested ``model_version`` and no second config download is needed.
        id2label = self.model.config.id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        missing = {"entailment", "neutral", "contradiction"} - set(self.labels)
        if missing:
            raise ValueError(
                "The model config must contain entailment, neutral and contradiction values "
                f"in the id2label dict (missing: {sorted(missing)})."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Score every document against ``query`` and aggregate the evidence.

        Documents are scanned in the order received. Each one gets ``doc.meta["entailment_info"]``.
        A document is kept only if its contradiction or entailment score exceeds
        ``entailment_contradiction_consideration``. Scanning stops early once the running
        average of contradiction or entailment over the kept documents exceeds
        ``entailment_contradiction_threshold``.

        :param query: The statement to verify (used as NLI hypothesis).
        :param documents: Candidate evidence documents (their content is used as NLI premise).
        :return: Tuple ``(output, "output_1")`` as required by Haystack nodes. ``output`` has
            the keys ``documents`` (the kept documents) and ``aggregate_entailment_info``
            (averaged scores rounded to 2 decimals; all 0.0 when no document was kept).
        """
        agg_con, agg_neu, agg_ent = 0.0, 0.0, 0.0
        considered_documents = []
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=[doc.content for doc in documents],
            hypothesis_batch=[query] * len(documents),
        )
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            con = entailment_info["contradiction"]
            neu = entailment_info["neutral"]
            ent = entailment_info["entailment"]
            if (con > self.entailment_contradiction_consideration) or (
                ent > self.entailment_contradiction_consideration
            ):
                considered_documents.append(doc)
                agg_con += con
                agg_neu += neu
                agg_ent += ent
                # If the first documents already give strong evidence of
                # entailment/contradiction, there is no need to consider less relevant ones.
                if max(agg_con, agg_ent) / len(considered_documents) > self.entailment_contradiction_threshold:
                    break

        considered_count = len(considered_documents)
        if considered_count:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / considered_count, 2),
                "neutral": round(agg_neu / considered_count, 2),
                "entailment": round(agg_ent / considered_count, 2),
            }
        else:
            # No document carried enough signal (or none were given): report no evidence
            # instead of dividing by zero.
            aggregate_entailment_info = {"contradiction": 0.0, "neutral": 0.0, "entailment": 0.0}

        entailment_checker_result = {
            "documents": considered_documents,
            "aggregate_entailment_info": aggregate_entailment_info,
        }
        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[Document]):
        """
        Score documents against queries pairwise (``zip`` of both lists) and wrap each result.

        Each document gets ``doc.meta["entailment_info"]``.

        NOTE(review): each aggregate value is the probability divided by ``doc.score`` and
        rounded to an integer. This mirrors the previous implementation. Confirm the intended
        semantics. ``doc.score`` must be a non-zero number.

        :return: Tuple ``(results, "output_1")`` with one result dict per document.
        """
        entailment_checker_result_batch = []
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=[doc.content for doc in documents], hypothesis_batch=queries
        )
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            aggregate_entailment_info = {
                "contradiction": round(entailment_info["contradiction"] / doc.score),
                "neutral": round(entailment_info["neutral"] / doc.score),
                "entailment": round(entailment_info["entailment"] / doc.score),
            }
            entailment_checker_result_batch.append(
                {
                    "documents": [doc],
                    "aggregate_entailment_info": aggregate_entailment_info,
                }
            )
        return entailment_checker_result_batch, "output_1"

    def get_entailment_dict(self, probs):
        """Map each model label (lower-cased) to its probability as a plain Python float."""
        return {label.lower(): float(prob) for label, prob in zip(self.labels, probs)}

    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        """
        Run the NLI model on premise/hypothesis pairs.

        Inputs are processed in chunks of at most ``self.batch_size`` pairs. Empty input
        returns an empty list without calling the tokenizer.

        :return: One label->probability dict per pair, in input order.
        """
        # NOTE(review): premise and hypothesis are joined with a single sep token rather
        # than encoded as a tokenizer text pair. Kept as-is to preserve scores.
        formatted_texts = [
            f"{premise}{self.tokenizer.sep_token}{hypothesis}"
            for premise, hypothesis in zip(premise_batch, hypothesis_batch)
        ]
        step = max(1, self.batch_size)
        results = []
        for start in range(0, len(formatted_texts), step):
            chunk = formatted_texts[start : start + step]
            with torch.inference_mode():
                inputs = self.tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(
                    self.devices[0]
                )
                logits = self.model(**inputs).logits
                probs_batch = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
            results.extend(self.get_entailment_dict(probs) for probs in probs_batch)
        return results
# Cached so that the index and the models are loaded only once, at start.
@st.cache_resource
def start_haystack():
    """
    Build the fact-checking pipeline.

    Loads the FAISS document store, the embedding retriever and the entailment
    checker, then wires them as ``Query -> retriever -> ec``.

    :return: The assembled Haystack ``Pipeline``.
    """
    # NOTE(review): the SQLite file is copied into the working directory, presumably
    # because the saved FAISS config refers to it through a relative path — confirm.
    shutil.copy("./data/faiss_document_store.db", ".")
    store = FAISSDocumentStore(
        faiss_index_path="./data/my_faiss_index.faiss",
        faiss_config_path="./data/my_faiss_index.json",
    )
    indexed_docs = store.get_document_count()
    print(f"Index size: {indexed_docs}")

    embedding_retriever = EmbeddingRetriever(document_store=store, embedding_model=RETRIEVER_MODEL)
    checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

    pipeline = Pipeline()
    wiring = (
        (embedding_retriever, "retriever", "Query"),
        (checker, "ec", "retriever"),
    )
    for node, node_name, upstream in wiring:
        pipeline.add_node(component=node, name=node_name, inputs=[upstream])
    return pipeline
def run_statement(statement: str, retriever_top_k: int = 5):
    """
    Verify ``statement`` against the indexed documents.

    :param statement: The sentence to check.
    :param retriever_top_k: Number of documents the retriever should return.
    :return: The output of the Haystack pipeline run.
    """
    pipe = start_haystack()
    # check_statement takes the pipeline as its first argument. Omitting it made the
    # statement string be used as the pipeline.
    return check_statement(pipe, statement, retriever_top_k)
@st.cache_resource
def _run_pipeline_cached(_pipe, statement: str, retriever_top_k: int):
    """Run the pipeline, caching by (statement, top_k). The leading underscore keeps ``_pipe`` out of Streamlit's hashing."""
    params = {"retriever": {"top_k": retriever_top_k}}
    return _pipe.run(statement, params=params)


def check_statement(pipe, statement: str, retriever_top_k: int = 5):
    """
    Run query and verify statement.

    The pipeline object holds a FAISS store and torch models, which Streamlit cannot hash
    as a cache key. It is therefore forwarded to a cached helper that excludes it from the
    key. Results are cached per ``(statement, retriever_top_k)``, which assumes a single
    pipeline instance (as provided by ``start_haystack``).

    :param pipe: The Haystack pipeline to run.
    :param statement: The sentence to verify, used as the pipeline query.
    :param retriever_top_k: Number of documents the retriever should return.
    :return: The pipeline output.
    """
    return _run_pipeline_cached(pipe, statement, retriever_top_k)
def set_state_if_absent(key, value):
    """Store ``value`` in ``st.session_state[key]`` unless that key already exists."""
    if key in st.session_state:
        return
    st.session_state[key] = value
# Callback that clears the previous output whenever the statement text changes.
def reset_results(*args):
    """
    Clear the stored answer, results and raw JSON from the session state.

    Extra positional arguments are accepted and ignored.
    """
    for state_key in ("answer", "results", "raw_json"):
        setattr(st.session_state, state_key, None)
def create_df_for_relevant_snippets(docs):
    """
    Create a styled table that contains all relevant snippets.

    :param docs: Documents whose ``meta["entailment_info"]`` holds ``contradiction``,
        ``neutral`` and ``entailment`` scores.
    :return: A pandas ``Styler`` with one row per document. Each row holds the content
        wrapped at 75 characters plus the three scores formatted to two decimals, with
        ``highlight_cols`` applied. An empty ``docs`` yields an empty table with the same
        columns.
    """
    rows = []
    for doc in docs:
        info = doc.meta["entailment_info"]
        rows.append(
            {
                "Content": doc.content,
                "con": f"{info['contradiction']:.2f}",
                "neu": f"{info['neutral']:.2f}",
                "ent": f"{info['entailment']:.2f}",
            }
        )
    # Explicit columns keep "Content" present even when there are no rows. Without them,
    # the wrap step below raises KeyError on an empty result.
    df = pd.DataFrame(rows, columns=["Content", "con", "neu", "ent"])
    df["Content"] = df["Content"].str.wrap(75)
    return df.style.apply(highlight_cols)
def highlight_cols(s):
    """
    Return per-cell CSS for a column: a fixed background colour for the score columns.

    :param s: A column (``pandas.Series``) passed by ``Styler.apply``.
    :return: One CSS string per cell. Empty strings for columns other than con/neu/ent.
    """
    palette = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
    colour = palette.get(s.name)
    css = "" if colour is None else f"background-color: {colour}"
    return [css] * len(s)
def main():
    """Render the statement-verification page and display the pipeline results."""
    # Persistent state, initialised once per session.
    for state_key, default in (
        ("statement", ""),
        ("answer", ""),
        ("results", None),
        ("raw_json", None),
    ):
        set_state_if_absent(state_key, default)

    st.write("# Verificação de Sentenças sobre Amazônia Azul")
    st.write()
    st.markdown("\n##### Insira uma sentença sobre a amazônia azul.\n")

    # Search bar
    user_statement = st.text_input(
        "",
        value=st.session_state.statement,
        max_chars=100,
        on_change=reset_results,
    )
    st.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    button_clicked = st.button("Run")
    # Run on an explicit click, or whenever the text differs from the last checked statement.
    should_run = button_clicked or user_statement != st.session_state.statement

    # Get results for the statement
    if should_run and user_statement:
        started_at = time.time()
        reset_results()
        st.session_state.statement = user_statement
        with st.spinner("&nbsp;&nbsp; Procurando a Similaridade no banco de sentenças..."):
            try:
                st.session_state.results = run_statement(user_statement, RETRIEVER_TOP_K)
                print(f"S: {user_statement}")
                finished_at = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {finished_at - started_at}")
            except JSONDecodeError:
                st.error("👓 &nbsp;&nbsp; Erro na document store.")
                return
            except Exception as err:
                logging.exception(err)
                st.error("🐞 &nbsp;&nbsp; Erro Genérico.")
                return

    # Display results (kept in session state, so they survive reruns)
    results = st.session_state.results
    if not results:
        return
    snippets = results["documents"]
    aggregate_info = results["aggregate_entailment_info"]
    st.markdown("###### Aggregate entailment information:")
    st.write(aggregate_info)
    st.markdown("###### Most Relevant snippets:")
    table = create_df_for_relevant_snippets(snippets)
    st.dataframe(table)
# Streamlit executes the app script as ``__main__``. The guard keeps imports of this
# module (e.g. to reuse EntailmentChecker) free of UI side effects.
if __name__ == "__main__":
    main()