# Fact_Checking / app.py
import os
import logging
import gradio as gr
from typing import List, Tuple
from huggingface_hub import InferenceClient
from langchain_openai import AzureChatOpenAI
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.prompts import PromptTemplate
# huggingface_key = os.getenv('HUGGINGFACE_KEY')
# print(huggingface_key)
# login(huggingface_key) # Huggingface api token
# Configure logging
logging.basicConfig(filename='factchecking.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
class FactChecking:
def __init__(self):
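        """Initialize the Azure OpenAI chat model used by the fact-checking chains and the Mixtral inference client used to generate the initial answer."""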
        self.llm = AzureChatOpenAI(
            azure_deployment="GPT-3"
        )
self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(self, question: str) -> str:
"""
Formats the input question into a specific structure for text generation.
Args:
question (str): The user's question to be formatted.
Returns:
str: The formatted prompt including instructions and the question.
"""
# Combine the instruction template with the user's question
prompt = f"[INST] you are the ai assitant your task is answr for the user question[/INST]"
prompt1 = f"[INST] {question} [/INST]"
return prompt+prompt1
    def mixtral_response(self, prompt, temperature=0.9, max_new_tokens=5000, top_p=0.95, repetition_penalty=1.0):
"""
Generates a response to the given prompt using text generation parameters.
Args:
prompt (str): The user's question.
temperature (float): Controls randomness in response generation.
max_new_tokens (int): The maximum number of tokens to generate.
top_p (float): Nucleus sampling parameter controlling diversity.
repetition_penalty (float): Penalty for repeating tokens.
Returns:
str: The generated response to the input prompt.
"""
# Adjust temperature and top_p values within acceptable ranges
temperature = float(temperature)
if temperature < 1e-2:
temperature = 1e-2
top_p = float(top_p)
generate_kwargs = dict(
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
seed=42,
)
# Simulating a call to a client's text generation API
        formatted_prompt = self.format_prompt(prompt)
        stream = self.client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
output = ""
for response in stream:
output += response.token.text
return output.replace("</s>","")
def find_different_sentences(self,chain_answer,llm_answer):
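        """
        Compare the fact-checking output with the original LLM answer and tag each sentence.
        Args:
            chain_answer (str): Fact-checking output with per-sentence (True/False) labels.
            llm_answer (str): The original answer produced by the Mixtral model.
        Returns:
            List[Tuple[str, str]]: (sentence, tag) pairs where tag is "factual" or "hallucinated".
        """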
        truth_values = []
        try:
            # Extract the (True/False) labels when sentences are separated by blank lines
            truth_values = [sentence.strip().split(' (')[1][:-1] for sentence in chain_answer.split('\n\n')]
        except Exception:
            try:
                # Fall back to single-newline separation between sentences
                truth_values = [sentence.strip().split(' (')[1][:-1] for sentence in chain_answer.split('\n')]
            except Exception:
                logging.warning("Could not extract truth values from the fact-checking output")
        # Map each True/False label to a tag understood by gr.HighlightedText
        tags = []
        for tag in truth_values:
            if "True" in tag:
                tags.append("factual")
            else:
                tags.append("hallucinated")
# Splitting llm_answer into sentences
llm_sentences = llm_answer.split('. ')
# Initializing an empty list to store tagged sentences
tagged_sentences = []
# Mapping the truth values to sentences in llm_answer
for sentence, truth_value in zip(llm_sentences, tags):
# Extracting the sentence without the truth value
sentence_text = sentence.split(' (')[0]
# Appending the sentence with its truth value
            tagged_sentences.append((sentence_text + ".", truth_value))
return tagged_sentences
    def find_hallucinatted_sentence(self, question: str) -> Tuple[str, List[Tuple[str, str]], str]:
        """
        Finds hallucinated sentences in the response to a given question.
        Args:
            question (str): The input question.
        Returns:
            Tuple[str, List[Tuple[str, str]], str]: The original Mixtral answer, a list of (sentence, tag) pairs, and the fact-checking result.
        """
try:
# Generate initial response using contract generator
mixtral_response = self.mixtral_response(question)
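            # Step 1: prompt the LLM to extract a bulleted list of facts from the generated answer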
template = """Given some text, extract a list of facts from the text.
Format your output as a bulleted list.
Text:
{question}
Facts:"""
prompt_template = PromptTemplate(input_variables=["question"], template=template)
question_chain = LLMChain(llm=self.llm, prompt=prompt_template)
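            # Step 2: prompt the LLM to label each extracted fact as true, false, or undetermined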
template = """You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.
Here is a bullet point list of facts:
{statement}
For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
If the fact is false, explain why."""
prompt_template = PromptTemplate(input_variables=["statement"], template=template)
assumptions_chain = LLMChain(llm=self.llm, prompt=prompt_template)
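            # Step 3: rewrite the original answer with each sentence labeled (True or False)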
extra_template = f" Original Summary:{mixtral_response} Using these checked assertions to write the original summary with true or false in sentence wised. For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output 'Undetermined'.***format: sentence (True or False) in braces.***"
template = """Below are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.
Checked Assertions:
{assertions}
"""
template += extra_template
prompt_template = PromptTemplate(input_variables=["assertions"], template=template)
answer_chain = LLMChain(llm=self.llm, prompt=prompt_template)
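            # Chain the three steps so each output feeds the next prompt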
            overall_chain = SimpleSequentialChain(chains=[question_chain, assumptions_chain, answer_chain], verbose=True)
answer = overall_chain.run(mixtral_response)
# Find different sentences between original result and fact checking result
prediction_list = self.find_different_sentences(answer,mixtral_response)
# prediction_list += generated_words
# Return the original result and list of hallucinated sentences
return mixtral_response,prediction_list,answer
        except Exception as e:
            logging.error(f"Error occurred in find_hallucinatted_sentence: {e}")
            return "", [], ""
def interface(self):
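        """Build and launch the Gradio UI: a question box, the raw LLM answer, the fact-checking result, and a highlighted view of factual vs. hallucinated sentences."""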
css=""".gradio-container {background: rgb(157,228,255);
background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"""
with gr.Blocks(css=css) as demo:
gr.HTML("""
<center><h1 style="color:#fff">Detect Hallucination</h1></center>""")
with gr.Row():
question = gr.Textbox(label="Question")
with gr.Row():
button = gr.Button(value="Submit")
with gr.Row():
mixtral_response = gr.Textbox(label="llm answer")
with gr.Row():
fact_checking_result = gr.Textbox(label="hallucinated detection result")
with gr.Row():
highlighted_prediction = gr.HighlightedText(
label="Sentence Hallucination detection",
combine_adjacent=True,
color_map={"hallucinated": "red", "factual": "green"},
show_legend=True)
            button.click(self.find_hallucinatted_sentence, question, [mixtral_response, highlighted_prediction, fact_checking_result])
demo.launch()
if __name__ == "__main__":
    hallucination_detection = FactChecking()
    hallucination_detection.interface()