Spaces:
Runtime error
Runtime error
from typing import List, Optional | |
import torch | |
import streamlit as st | |
import pandas as pd | |
import random | |
import time | |
import logging | |
from json import JSONDecodeError | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig | |
from haystack import Document | |
from haystack.document_stores import FAISSDocumentStore | |
from haystack.modeling.utils import initialize_device_settings | |
from haystack.nodes import EmbeddingRetriever | |
from haystack.pipelines import Pipeline | |
from haystack.nodes.base import BaseComponent | |
from haystack.schema import Document | |
from Project.Fact_Checking_Blue_Amazon.config import ( | |
RETRIEVER_TOP_K, | |
RETRIEVER_MODEL, | |
NLI_MODEL, | |
) | |
class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the statement.
    It enriches the documents' metadata with entailment information.
    It also returns aggregate entailment information.
    """

    outgoing_edges = 1

    # Labels read by `run`; the model's id2label must provide all of them.
    _REQUIRED_LABELS = ("contradiction", "neutral", "entailment")

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 100,
        entailment_contradiction_consideration: float = 0.6,
        entailment_contradiction_threshold: float = 0.8,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub.
            Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of (document, statement) pairs sent to the model per forward pass.
        :param entailment_contradiction_consideration: A document is kept (and counted in the
            aggregate) only if its entailment or contradiction score is greater than this value.
        :param entailment_contradiction_threshold: Stop examining further documents once the
            running average entailment or contradiction score of the kept documents is
            greater than this value.
        :raises ValueError: If the model's id2label does not contain all of the labels
            "contradiction", "neutral" and "entailment".
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.entailment_contradiction_consideration = entailment_contradiction_consideration
        self.model.to(str(self.devices[0]))

        # Take the label order from the model that was actually loaded: this honours
        # `model_version` and avoids downloading the config a second time.
        id2label = self.model.config.id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        missing = [label for label in self._REQUIRED_LABELS if label not in self.labels]
        if missing:
            raise ValueError(
                f"The model config must contain the labels {list(self._REQUIRED_LABELS)} "
                f"in the id2label dict; missing: {missing}."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Check the entailment between ``query`` and each document, and aggregate the scores.

        Documents are examined in the given order. Each examined document gets
        ``doc.meta["entailment_info"]`` set to its per-label probabilities. A document
        is kept when its contradiction or entailment score is greater than
        ``entailment_contradiction_consideration``. Examination stops early once the
        average contradiction or entailment score of the kept documents is greater than
        ``entailment_contradiction_threshold``.

        :param query: The statement (hypothesis) to verify.
        :param documents: Candidate evidence documents (premises).
        :returns: ``(result, "output_1")``, the tuple shape Haystack's ``BaseComponent``
            requires. ``result["documents"]`` holds the kept documents.
            ``result["aggregate_entailment_info"]`` holds their average contradiction /
            neutral / entailment scores rounded to 2 decimals, or all 0.0 when no
            document was kept.
        """
        considered_documents = []
        agg_con, agg_neu, agg_ent = 0.0, 0.0, 0.0

        premise_batch = [doc.content for doc in documents]
        hypothesis_batch = [query] * len(documents)
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=premise_batch, hypothesis_batch=hypothesis_batch
        )

        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            con = entailment_info["contradiction"]
            neu = entailment_info["neutral"]
            ent = entailment_info["entailment"]

            # Only documents with clear evidence either way contribute to the aggregate.
            if not (
                con > self.entailment_contradiction_consideration
                or ent > self.entailment_contradiction_consideration
            ):
                continue

            considered_documents.append(doc)
            agg_con += con
            agg_neu += neu
            agg_ent += ent

            # If the first (most relevant) documents already show strong evidence of
            # entailment/contradiction, there is no need to consider less relevant ones.
            if max(agg_con, agg_ent) / len(considered_documents) > self.entailment_contradiction_threshold:
                break

        n_considered = len(considered_documents)
        if n_considered:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / n_considered, 2),
                "neutral": round(agg_neu / n_considered, 2),
                "entailment": round(agg_ent / n_considered, 2),
            }
        else:
            # No document carried strong enough evidence; report zeros instead of
            # dividing by zero.
            aggregate_entailment_info = {"contradiction": 0.0, "neutral": 0.0, "entailment": 0.0}

        entailment_checker_result = {
            "documents": considered_documents,
            "aggregate_entailment_info": aggregate_entailment_info,
        }
        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[List[Document]]):
        """
        Batch mode is not supported by this node.

        Defined because Haystack 1.x declares ``run_batch`` on ``BaseComponent``.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("EntailmentChecker does not support run_batch; use run() instead.")

    def get_entailment_dict(self, probs):
        """Map each model label (lower-cased) to the matching probability in ``probs``."""
        return {k.lower(): v for k, v in zip(self.labels, probs)}

    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        """
        Score (premise, hypothesis) pairs with the NLI model.

        Pairs are taken in lockstep from the two lists and processed ``self.batch_size``
        at a time.

        :returns: One ``{label: probability}`` dict (Python floats) per pair, in input order.
        """
        pairs = list(zip(premise_batch, hypothesis_batch))
        results = []
        for start in range(0, len(pairs), self.batch_size):
            chunk = pairs[start : start + self.batch_size]
            premises = [premise for premise, _ in chunk]
            hypotheses = [hypothesis for _, hypothesis in chunk]
            with torch.inference_mode():
                # Encode as text pairs: the tokenizer inserts the model's own separator /
                # segment ids, and truncation trims the longer sequence (the premise)
                # instead of cutting the hypothesis off the end of a joined string.
                inputs = self.tokenizer(
                    premises, hypotheses, return_tensors="pt", padding=True, truncation=True
                ).to(self.devices[0])
                logits = self.model(**inputs).logits
                probs_batch = torch.nn.functional.softmax(logits, dim=-1).cpu().tolist()
            results.extend(self.get_entailment_dict(probs) for probs in probs_batch)
        return results
def _cache_resource(func):
    """
    Cache ``func``'s return value across Streamlit reruns and sessions.

    Streamlit re-executes the whole script on every interaction, so without caching
    the FAISS index and both models would be reloaded each time. Uses
    ``st.cache_resource`` (Streamlit >= 1.18). On older Streamlit versions, which lack
    it, ``func`` is returned uncached.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        return func
    return cache_resource(func)


# cached to make index and models load only at start
@_cache_resource
def start_haystack():
    """
    Load document store, retriever, entailment checker and create pipeline.

    The FAISS index/config paths are relative to the current working directory.

    :returns: A Haystack ``Pipeline``: Query -> "retriever" -> "ec" (EntailmentChecker).
    """
    document_store = FAISSDocumentStore(
        faiss_index_path="../data/my_faiss_index.faiss",
        faiss_config_path="../data/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")

    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=RETRIEVER_MODEL,
    )
    entailment_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
    )

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
    return pipe


pipe = start_haystack()
def check_statement(pipe, statement: str, retriever_top_k: int = 5):
    """
    Verify ``statement`` by running it through ``pipe``.

    ``retriever_top_k`` is forwarded as the ``top_k`` parameter of the node named
    "retriever". The pipeline's output is returned unchanged.
    """
    retriever_params = {"top_k": retriever_top_k}
    return pipe.run(statement, params={"retriever": retriever_params})
def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in ``st.session_state`` unless ``key`` is already present."""
    if key in st.session_state:
        return
    st.session_state[key] = value
# Small callback to reset the interface when the text of the statement changes
def reset_results(*args):
    """
    Set ``answer``, ``results`` and ``raw_json`` in ``st.session_state`` to None.

    Positional arguments are ignored. The function is used as the ``on_change``
    callback of the statement input in ``main``.
    """
    for state_key in ("answer", "results", "raw_json"):
        setattr(st.session_state, state_key, None)
def create_df_for_relevant_snippets(docs):
    """
    Create a styled dataframe that contains all relevant snippets.

    Each row holds a document's content, wrapped at 75 characters, and its
    contradiction ("con"), neutral ("neu") and entailment ("ent") scores from
    ``doc.meta["entailment_info"]``, formatted with two decimals. An empty ``docs``
    yields an empty table with the same columns.

    :returns: A pandas ``Styler`` with the score columns coloured by ``highlight_cols``.
    """
    rows = []
    for doc in docs:
        info = doc.meta["entailment_info"]
        rows.append(
            {
                "Content": doc.content,
                "con": f"{info['contradiction']:.2f}",
                "neu": f"{info['neutral']:.2f}",
                "ent": f"{info['entailment']:.2f}",
            }
        )
    # Explicit columns keep "Content" present even when `docs` is empty; otherwise
    # pd.DataFrame([]) has no columns and the lookup below raises KeyError.
    df = pd.DataFrame(rows, columns=["Content", "con", "neu", "ent"])
    df["Content"] = df["Content"].str.wrap(75)
    return df.style.apply(highlight_cols)
def highlight_cols(s):
    """
    Return one CSS string per cell of column ``s``, for use with ``Styler.apply``.

    The "con", "neu" and "ent" columns get a fixed background colour. Any other
    column gets empty styles.
    """
    color = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}.get(s.name)
    css = "" if color is None else "background-color: {}".format(color)
    return [css] * len(s)
def main():
    """
    Render the statement-verification page.

    Reads a statement from the text input. When the "Run" button is pressed or the
    statement changed, it runs the statement through the module-level ``pipe`` and
    stores the output in ``st.session_state.results``. It then displays the
    aggregate entailment information and the relevant snippets.
    """
    # Persistent state
    set_state_if_absent("statement", "")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)

    st.write("# Verificação de Sentenças sobre Amazônia Azul")
    st.write()
    st.markdown(
        """
##### Insira uma sentença sobre a amazônia azul.
"""
    )

    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )
    st.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    run_pressed = st.button("Run")

    run_query = run_pressed or statement != st.session_state.statement

    # Get results for query
    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner(" Procurando a Similaridade no banco de sentenças..."):
            try:
                # check_statement's first parameter is the pipeline; pass it explicitly.
                st.session_state.results = check_statement(pipe, statement, RETRIEVER_TOP_K)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                st.error(
                    "👓 Erro na document store."
                )
                return
            except Exception as e:
                # Top-level UI boundary: log the traceback, show a generic message.
                logging.exception(e)
                st.error("🐞 Erro Genérico.")
                return

    # Display results
    if st.session_state.results:
        docs = st.session_state.results["documents"]
        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]

        st.markdown("###### Aggregate entailment information:")
        st.write(agg_entailment_info)
        st.markdown("###### Most Relevant snippets:")
        df = create_df_for_relevant_snippets(docs)
        st.dataframe(df)


# `streamlit run` executes this script as __main__, so the guard keeps the app
# working while making the module importable without rendering the UI.
if __name__ == "__main__":
    main()