#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by zd302 at 17/07/2024
from fastapi import FastAPI
from pydantic import BaseModel
# from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
import uvicorn
import spaces
app = FastAPI()
# ---------------------------------------------------------------------------------------------------------------------
import gradio as gr
import tqdm
import torch
import numpy as np
from time import sleep
from datetime import datetime
import threading
import gc
import os
import json
import pytorch_lightning as pl
from urllib.parse import urlparse
from accelerate import Accelerator
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from rank_bm25 import BM25Okapi
# import bm25s
# import Stemmer # optional: for stemming
from html2lines import url2lines
from googleapiclient.discovery import build
from averitec.models.DualEncoderModule import DualEncoderModule
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
from averitec.data.sample_claims import CLAIMS_Type
# ---------------------------------------------------------------------------
# load .env
from utils import create_user_id
user_id = create_user_id()
from azure.storage.fileshare import ShareServiceClient
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    # python-dotenv is optional; fall back to variables already set in the environment.
    pass
# ---------------------------------------------------------------------------
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
account_url = os.environ["AZURE_ACCOUNT_URL"]
credential = {
"account_key": os.environ['AZURE_ACCOUNT_KEY'],
"account_name": os.environ['AZURE_ACCOUNT_NAME']
}
file_share_name = "averitec"
azure_service = ShareServiceClient(account_url=account_url, credential=credential)
azure_share_client = azure_service.get_share_client(file_share_name)
# ---------------------------------------------------------------------------------------------------------------------
import requests
from bs4 import BeautifulSoup
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
import nltk
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk import pos_tag, word_tokenize, sent_tokenize
import spacy
os.system("python -m spacy download en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
# ---------------------------------------------------------------------------
with open('averitec/data/train.json', 'r') as f:
    train_examples = json.load(f)
def claim2prompts(example):
claim = example["claim"]
# claim_str = "Claim: " + claim + "||Evidence: "
claim_str = "Evidence: "
for question in example["questions"]:
q_text = question["question"].strip()
if len(q_text) == 0:
continue
if not q_text[-1] == "?":
q_text += "?"
answer_strings = []
for a in question["answers"]:
if a["answer_type"] in ["Extractive", "Abstractive"]:
answer_strings.append(a["answer"])
if a["answer_type"] == "Boolean":
answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())
for a_text in answer_strings:
if not a_text[-1] in [".", "!", ":", "?"]:
a_text += "."
# prompt_lookup_str = claim + " " + a_text
prompt_lookup_str = a_text
this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text
yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n"))
def generate_reference_corpus(reference_file):
all_data_corpus = []
tokenized_corpus = []
for train_example in train_examples:
train_claim = train_example["claim"]
speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len(
train_example["speaker"]) > 1 else "they"
questions = [q["question"] for q in train_example["questions"]]
claim_dict_builder = {}
claim_dict_builder["claim"] = train_claim
claim_dict_builder["speaker"] = speaker
claim_dict_builder["questions"] = questions
tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"]))
all_data_corpus.append(claim_dict_builder)
return tokenized_corpus, all_data_corpus
def generate_step2_reference_corpus(reference_file):
prompt_corpus = []
tokenized_corpus = []
for example in train_examples:
for lookup_str, prompt in claim2prompts(example):
entry = nltk.word_tokenize(lookup_str)
tokenized_corpus.append(entry)
prompt_corpus.append(prompt)
return tokenized_corpus, prompt_corpus
reference_file = "averitec/data/train.json"
tokenized_corpus0, all_data_corpus0 = generate_reference_corpus(reference_file)
qg_bm25 = BM25Okapi(tokenized_corpus0)
tokenized_corpus1, prompt_corpus1 = generate_step2_reference_corpus(reference_file)
prompt_bm25 = BM25Okapi(tokenized_corpus1)
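# Two BM25 indices are built from the AVeriTeC training data above:
#   - qg_bm25 ranks training claims, so the most similar ones can serve as in-context
#     examples when prompting the question-generation model for a new claim.
#   - prompt_bm25 ranks training answers, so the closest evidence/question pairs can be
#     used as few-shot prompts when generating questions for retrieved evidence.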
# ---------------------------------------------------------------------------------------------------------------------
# ---------- Setting ----------
# ---------- Load Veracity and Justification prediction model ----------
print("Loading models ...")
LABEL = [
"Supported",
"Refuted",
"Not Enough Evidence",
"Conflicting Evidence/Cherrypicking",
]
if torch.cuda.is_available():
# question generation
qg_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-1b1")
qg_model = BloomForCausalLM.from_pretrained("bigscience/bloom-1b1", torch_dtype=torch.bfloat16).to('cuda')
# rerank
rerank_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    rerank_bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification")  # Must specify single_label for some reason
    best_checkpoint = "averitec/pretrained_models/bert_dual_encoder.ckpt"
    rerank_trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=rerank_tokenizer, model=rerank_bert_model)
# Veracity
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model)
# Justification
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model)
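# NOTE: the four models above (question generation, evidence reranking, veracity
# classification, justification generation) are only initialized when CUDA is available;
# on a CPU-only machine the GPU-decorated functions below would fail with a NameError.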
# ---------------------------------------------------------------------------
# ----------------------------------------------------------------------------
class Docs:
    def __init__(self, metadata=None, page_content=""):
        # Avoid a shared mutable default argument: each instance gets its own metadata dict.
        self.metadata = metadata if metadata is not None else {}
        self.page_content = page_content
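# Docs is a minimal container for one piece of retrieved evidence: `metadata` holds
# provenance fields (title, url, query, answer, ...) and `page_content` carries a
# printable "Question/Answer" or "Title/Evidence" rendering of the same information.
# Illustrative example (hypothetical values):
#   Docs({"title": "Euro 2024", "url": "https://..."}, "Title: Euro 2024\nEvidence: ...")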
# ------------------------------ Googleretriever -----------------------------
def doc2prompt(doc):
prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[
"claim"].strip() + "\". Criticism includes questions like: "
questions = [q.strip() for q in doc["questions"]]
return prompt_parts + " ".join(questions)
def docs2prompt(top_docs):
return "\n\n".join([doc2prompt(d) for d in top_docs])
@spaces.GPU
def prompt_question_generation(test_claim, speaker="they", topk=10):
# --------------------------------------------------
# test claim
s = qg_bm25.get_scores(nltk.word_tokenize(test_claim))
top_n = np.argsort(s)[::-1][:topk]
docs = [all_data_corpus0[i] for i in top_n]
# --------------------------------------------------
prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \
"\". Criticism includes questions like: "
sentences = [prompt]
inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
outputs = qg_model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
tgt_text = qg_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
in_len = len(sentences[0])
questions_str = tgt_text[in_len:].split("\n")[0]
qs = questions_str.split("?")
qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300]
#
generate_question = [{"question": q, "answers": []} for q in qs]
return generate_question
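# ---------- Google search helpers: date normalization, POS-based query building, Custom Search calls, page scraping ----------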
def check_claim_date(check_date):
    try:
        year, month, date = check_date.split("-")
    except Exception:
        # Fall back to a fixed default when the claim date is missing or malformed.
        month, date, year = "01", "01", "2022"
if len(year) == 2 and int(year) <= 30:
year = "20" + year
elif len(year) == 2:
year = "19" + year
elif len(year) == 1:
year = "200" + year
if len(month) == 1:
month = "0" + month
if len(date) == 1:
date = "0" + date
sort_date = year + month + date
return sort_date
def string_to_search_query(text, author):
parts = word_tokenize(text.strip())
tags = pos_tag(parts)
keep_tags = ["CD", "JJ", "NN", "VB"]
if author is not None:
search_string = author.split()
else:
search_string = []
    # pos_tag already returns (token, tag) pairs, so iterate over them directly.
    for token, tag in tags:
        if any(tag.startswith(keep_tag) for keep_tag in keep_tags):
            search_string.append(token)
search_string = " ".join(search_string)
return search_string
def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
search_results = []
for i in range(1):
try:
search_results += google_search(
search_string,
api_key,
search_engine_id,
num=3, # num=10,
start=0 + 10 * page,
sort="date:r:19000101:" + sort_date,
dateRestrict=None,
gl="US"
)
break
        except Exception:
            sleep(1)
return search_results
def google_search(search_term, api_key, cse_id, **kwargs):
service = build("customsearch", "v1", developerKey=api_key)
res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()
if "items" in res:
return res['items']
else:
return []
def get_domain_name(url):
if '://' not in url:
url = 'http://' + url
domain = urlparse(url).netloc
if domain.startswith("www."):
return domain[4:]
else:
return domain
def get_text_from_link(url_link):
page_lines = url2lines(url_link)
return "\n".join([url_link] + page_lines)
def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1): # n_pages=3
# default config
api_key = os.environ["GOOGLE_API_KEY"]
search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
blacklist = [
"jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download
"facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
"ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up
"nlp.cs.princeton.edu",
"huggingface.co"
]
blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors
"/glove.",
"ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
"https://web.mit.edu/adamrose/Public/googlelist",
]
# save to folder
store_folder = "averitec/data/store/retrieved_docs"
#
index = 0
questions = [q["question"] for q in generate_question][:3]
# questions = [q["question"] for q in generate_question] # ori
# check the date of the claim
current_date = datetime.now().strftime("%Y-%m-%d")
sort_date = check_claim_date(current_date) # check_date="2022-01-01"
#
search_strings = []
search_types = []
search_string_2 = string_to_search_query(claim, None)
search_strings += [search_string_2, claim, ]
search_types += ["claim", "claim-noformat", ]
search_strings += questions
search_types += ["question" for _ in questions]
# start to search
search_results = []
visited = {}
store_counter = 0
worker_stack = list(range(10))
retrieve_evidence = []
for this_search_string, this_search_type in zip(search_strings, search_types):
for page_num in range(n_pages):
search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
this_search_string, page=page_num)
for result in search_results:
link = str(result["link"])
domain = get_domain_name(link)
if domain in blacklist:
continue
broken = False
for b_file in blacklist_files:
if b_file in link:
broken = True
if broken:
continue
if link.endswith(".pdf") or link.endswith(".doc"):
continue
store_file_path = ""
if link in visited:
web_text = visited[link]
else:
web_text = get_text_from_link(link)
visited[link] = web_text
line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, web_text]
retrieve_evidence.append(line)
return retrieve_evidence
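# Each row returned by averitec_search has the form:
#   [index, claim, link, page_num, search_string, search_type, web_text]
# where web_text is the scraped page with its source URL prepended as the first line.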
@spaces.GPU
def decorate_with_questions(claim, retrieve_evidence, top_k=3): # top_k=5, 10, 100
#
tokenized_corpus = []
all_data_corpus = []
for retri_evi in tqdm.tqdm(retrieve_evidence):
# store_file = retri_evi[-1]
# with open(store_file, 'r') as f:
web_text = retri_evi[-1]
lines_in_web = web_text.split("\n")
first = True
for line in lines_in_web:
# for line in f:
line = line.strip()
if first:
first = False
location_url = line
continue
if len(line) > 3:
entry = nltk.word_tokenize(line)
if (location_url, line) not in all_data_corpus:
tokenized_corpus.append(entry)
all_data_corpus.append((location_url, line))
    if len(tokenized_corpus) == 0:
        # Nothing was scraped, so there is no evidence to rank and no QA pairs to generate.
        return []
bm25 = BM25Okapi(tokenized_corpus)
s = bm25.get_scores(nltk.word_tokenize(claim))
top_n = np.argsort(s)[::-1][:top_k]
docs = [all_data_corpus[i] for i in top_n]
generate_qa_pairs = []
# Then, generate questions for those top 50:
for doc in tqdm.tqdm(docs):
# prompt_lookup_str = example["claim"] + " " + doc[1]
prompt_lookup_str = doc[1]
prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
prompt_n = 10
prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
prompt_docs = [prompt_corpus1[i] for i in prompt_top_n]
claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
prompt = "\n\n".join(prompt_docs + [claim_prompt])
sentences = [prompt]
inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(qg_model.device)
outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True)
tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
# We are not allowed to generate more than 250 characters:
tgt_text = tgt_text[:250]
qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
generate_qa_pairs.append(qa_pair)
return generate_qa_pairs
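# decorate_with_questions: BM25-ranks the scraped sentences against the claim, keeps the
# top_k, and for each one prompts the Bloom model with the 10 most similar training
# evidence/question pairs to generate a question answered by that sentence, returning
# [question, evidence_sentence, source_url] triples.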
def triple_to_string(x):
return " ".join([item.strip() for item in x])
@spaces.GPU
def rerank_questions(claim, bm25_qas, topk=3):
#
strs_to_score = []
values = []
for question, answer, source in bm25_qas:
str_to_score = triple_to_string([claim, question, answer])
strs_to_score.append(str_to_score)
values.append([question, answer, source])
if len(bm25_qas) > 0:
encoded_dict = rerank_tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(rerank_trained_model.device)
input_ids = encoded_dict['input_ids']
attention_masks = encoded_dict['attention_mask']
scores = torch.softmax(rerank_trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1]
top_n = torch.argsort(scores, descending=True)[:topk]
pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n]
else:
pass_through = []
top3_qa_pairs = pass_through
return top3_qa_pairs
@spaces.GPU
def Googleretriever(query):
# ----- Generate QA pairs using AVeriTeC
# step 1: generate questions for the query/claim using Bloom
generate_question = prompt_question_generation(query)
# step 2: retrieve evidence for the generated questions using Google API
retrieve_evidence = averitec_search(query, generate_question)
# step 3: generate QA pairs for each retrieved document
bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence)
# step 4: rerank QA pairs
top3_qa_pairs = rerank_questions(query, bm25_qa_pairs)
# Add score to metadata
results = []
for i, qa in enumerate(top3_qa_pairs):
metadata = dict()
metadata['name'] = qa['question']
metadata['url'] = qa['source_url']
metadata['cached_source_url'] = qa['source_url']
metadata['short_name'] = "Evidence {}".format(i + 1)
metadata['page_number'] = ""
metadata['title'] = qa['question']
metadata['evidence'] = qa['answers']
metadata['query'] = qa['question']
metadata['answer'] = qa['answers']
metadata['page_content'] = "Question: " + qa['question'] + "
" + "Answer: " + qa['answers']
page_content = f"""{metadata['page_content']}"""
results.append(Docs(metadata, page_content))
return results
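# Illustrative usage (assumes GOOGLE_API_KEY and GOOGLE_SEARCH_ENGINE_ID are set; the
# claim string is only an example):
#   docs = Googleretriever("England won the Euro 2024.")
#   for d in docs:
#       print(d.metadata["short_name"], d.metadata["url"])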
# ------------------------------ Googleretriever -----------------------------
# ------------------------------ Wikipediaretriever --------------------------
def search_entity_wikipedia(entity):
find_evidence = []
page_py = wiki_wiki.page(entity)
if page_py.exists():
introduction = page_py.summary
find_evidence.append([str(entity), introduction])
return find_evidence
def clean_str(p):
return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
def find_similar_wikipedia(entity, relevant_wikipages):
    # If fewer than 5 relevant Wikipedia pages have been collected for the entity, search for similar page titles.
ent_ = entity.replace(" ", "+")
search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
response_text = requests.get(search_url).text
soup = BeautifulSoup(response_text, features="html.parser")
result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
if result_divs:
result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
similar_titles = result_titles[:5]
saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
for _t in similar_titles:
if _t not in saved_titles and len(relevant_wikipages) < 5:
                _evi = search_entity_wikipedia(_t)
# _evi = search_step(_t)
relevant_wikipages.extend(_evi)
return relevant_wikipages
def find_evidence_from_wikipedia(claim):
#
doc = nlp(claim)
#
wikipedia_page = []
for ent in doc.ents:
        relevant_wikipages = search_entity_wikipedia(str(ent))
if len(relevant_wikipages) < 5:
relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
wikipedia_page.extend(relevant_wikipages)
return wikipedia_page
def bm25_retriever(query, corpus, topk=3):
bm25 = BM25Okapi(corpus)
#
query_tokens = word_tokenize(query)
scores = bm25.get_scores(query_tokens)
top_n = np.argsort(scores)[::-1][:topk]
top_n_scores = [scores[i] for i in top_n]
return top_n, top_n_scores
def relevant_sentence_retrieval(query, wiki_intro, k):
# 1. Create corpus here
corpus, sentences = [], []
titles = []
for i, (title, intro) in enumerate(wiki_intro):
sents_in_intro = sent_tokenize(intro)
for sent in sents_in_intro:
corpus.append(word_tokenize(sent))
sentences.append(sent)
titles.append(title)
# ----- BM25
bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
bm25_top_n_titles = [titles[i] for i in bm25_top_n]
return bm25_top_n_sents, bm25_top_n_titles
# ------------------------------ Wikipediaretriever -----------------------------
def Wikipediaretriever(claim):
# 1. extract relevant wikipedia pages from wikipedia dumps
wikipedia_page = find_evidence_from_wikipedia(claim)
# 2. extract relevant sentences from extracted wikipedia pages
sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
#
results = []
for i, (sent, title) in enumerate(zip(sents, titles)):
metadata = dict()
metadata['name'] = claim
metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
metadata['short_name'] = "Evidence {}".format(i + 1)
metadata['page_number'] = ""
metadata['query'] = sent
metadata['title'] = title
metadata['evidence'] = sent
metadata['answer'] = ""
metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata['evidence']
page_content = f"""{metadata['page_content']}"""
results.append(Docs(metadata, page_content))
return results
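# Wikipediaretriever: spaCy NER extracts entities from the claim, their Wikipedia
# summaries (plus up to 5 similar page titles per entity) are collected, and BM25
# selects the k most claim-relevant sentences, each wrapped in a Docs object.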
# ------------------------------ Veracity Prediction ------------------------------
class SequenceClassificationDataLoader(pl.LightningDataModule):
def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
super().__init__()
self.tokenizer = tokenizer
self.data_file = data_file
self.batch_size = batch_size
self.add_extra_nee = add_extra_nee
def tokenize_strings(
self,
source_sentences,
max_length=400,
pad_to_max_length=False,
return_tensors="pt",
):
encoded_dict = self.tokenizer(
source_sentences,
max_length=max_length,
padding="max_length" if pad_to_max_length else "longest",
truncation=True,
return_tensors=return_tensors,
)
input_ids = encoded_dict["input_ids"]
attention_masks = encoded_dict["attention_mask"]
return input_ids, attention_masks
def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
if bool_explanation is not None and len(bool_explanation) > 0:
bool_explanation = ", because " + bool_explanation.lower().strip()
else:
bool_explanation = ""
return (
"[CLAIM] "
+ claim.strip()
+ " [QUESTION] "
+ question.strip()
+ " "
+ answer.strip()
+ bool_explanation
)
@spaces.GPU
def veracity_prediction(claim, evidence):
dataLoader = SequenceClassificationDataLoader(
tokenizer=veracity_tokenizer,
data_file="this_is_discontinued",
batch_size=32,
add_extra_nee=False,
)
evidence_strings = []
for evi in evidence:
evidence_strings.append(dataLoader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], ""))
if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI.
pred_label = "Not Enough Evidence"
return pred_label
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
example_support = torch.argmax(veracity_model(tokenized_strings.to(veracity_model.device), attention_mask=attention_mask.to(veracity_model.device)).logits, axis=1)
# example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
has_unanswerable = False
has_true = False
has_false = False
for v in example_support:
if v == 0:
has_true = True
if v == 1:
has_false = True
if v in (2, 3,): # TODO another hack -- we cant have different labels for train and test so we do this
has_unanswerable = True
if has_unanswerable:
answer = 2
elif has_true and not has_false:
answer = 0
elif not has_true and has_false:
answer = 1
else:
answer = 3
pred_label = LABEL[answer]
return pred_label
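# Aggregation rule over per-evidence predictions: if any QA pair is classified as
# "Not Enough Evidence" or "Conflicting Evidence/Cherrypicking", the claim is labelled
# "Not Enough Evidence"; otherwise all-supporting -> "Supported", all-refuting ->
# "Refuted", and a mix of the two -> "Conflicting Evidence/Cherrypicking".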
# ------------------------------ Justification Generation ------------------------------
def extract_claim_str(claim, evidence, verdict_label):
claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
for evi in evidence:
q_text = evi.metadata['query'].strip()
if len(q_text) == 0:
continue
if not q_text[-1] == "?":
q_text += "?"
answer_strings = []
answer_strings.append(evi.metadata['answer'])
claim_str += q_text
for a_text in answer_strings:
if a_text:
if not a_text[-1] == ".":
a_text += "."
claim_str += " " + a_text.strip()
claim_str += " "
claim_str += " [VERDICT] " + verdict_label
return claim_str
@spaces.GPU
def justification_generation(claim, evidence, verdict_label):
#
# claim_str = extract_claim_str(claim, evidence, verdict_label)
claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
for evi in evidence:
q_text = evi.metadata['query'].strip()
if len(q_text) == 0:
continue
if not q_text[-1] == "?":
q_text += "?"
answer_strings = []
answer_strings.append(evi.metadata['answer'])
claim_str += q_text
for a_text in answer_strings:
if a_text:
if not a_text[-1] == ".":
a_text += "."
claim_str += " " + a_text.strip()
claim_str += " "
claim_str += " [VERDICT] " + verdict_label
#
    claim_str = claim_str.strip()
pred_justification = justification_model.generate(claim_str, device=justification_model.device)
# pred_justification = justification_model.generate(claim_str, device=device)
return pred_justification.strip()
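# justification_generation: concatenates the claim, every question/answer pair, and the
# predicted verdict into a "[CLAIM] ... [EVIDENCE] ... [VERDICT] ..." string and lets the
# fine-tuned BART checkpoint generate the free-text justification.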
# ---------------------------------------------------------------------------------------------------------------------
class Item(BaseModel):
claim: str
source: str
@app.get("/")
@spaces.GPU
def greet_json():
return {"Hello": "World!"}
def log_on_azure(file, logs, azure_share_client):
logs = json.dumps(logs)
file_client = azure_share_client.get_file_client(file)
file_client.upload_file(logs)
@app.post("/predict/")
@spaces.GPU
def fact_checking(item: Item):
# claim = item['claim']
# source = item['source']
claim = item.claim
source = item.source
# Step1: Evidence Retrieval
if source == "Wikipedia":
evidence = Wikipediaretriever(claim)
elif source == "Google":
evidence = Googleretriever(claim)
# Step2: Veracity Prediction and Justification Generation
verdict_label = veracity_prediction(claim, evidence)
justification_label = justification_generation(claim, evidence, verdict_label)
############################################################
evidence_list = []
for evi in evidence:
title_str = evi.metadata['title']
evi_str = evi.metadata['evidence']
url_str = evi.metadata['url']
evidence_list.append([title_str, evi_str, url_str])
try:
# Log answer on Azure Blob Storage
# IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
if os.environ["AZURE_ISSAVE"] == "TRUE":
timestamp = str(datetime.now().timestamp())
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
file = timestamp + ".json"
logs = {
"user_id": str(user_id),
"claim": claim,
"sources": source,
"evidence": evidence_list,
"answer": [verdict_label, justification_label],
"time": timestamp,
}
log_on_azure(file, logs, azure_share_client)
except Exception as e:
print(f"Error logging on Azure Blob Storage: {e}")
raise gr.Error(
f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
##########
return {"Verdict": verdict_label, "Justification": justification_label, "Evidence": evidence_list}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
# if __name__ == "__main__":
# item = {
# "claim": "England won the Euro 2024.",
# "source": "Google", # Google, Wikipedia
# }
#
# results = fact_checking(item)
#
# print(results)