|
from step1_api_claim_extractor import ClaimExtractor |
|
from step2_api_fix_passage_anchors import FixAnchors |
|
from step3_api_identify_objective_claims import ClassifyClaims |
|
from step41_api_fetch_cohere_wikipedia_evidence import CohereEvidence |
|
from step42_api_fetch_google_search_evidence import GoogleEvidence |
|
from step5_api_embed_search_results import EmbedResults |
|
from step6_api_claims_to_evidence import ClaimToEvidence |
|
from step7_api_check_claims_against_evidence import CheckClaimAgainstEvidence |
|
from step8_api_format_fact_checked_document import FormatDocument |
|
|
|
import argparse |
|
import json |
|
import os |
|
import copy |
|
from dotenv import load_dotenv |
|
|
|
load_dotenv() |
|
|
|
|
|
def get_fact_checked(text_input, model="gpt-3.5-turbo", mode="slow"): |
|
text_input = text_input.strip() |
|
|
|
results = {} |
|
|
|
|
|
print("Step1: Extracting claims") |
|
step1 = ClaimExtractor(model=model) |
|
step1_json = step1.extract_claims(text_input) |
|
results["step1_claims"] = copy.deepcopy(step1_json) |
|
|
|
|
|
print("Step2: Anchoring claims") |
|
try: |
|
step2 = FixAnchors(model=model) |
|
step2_json = step2.fix_passage_anchors(step1_json, text_input) |
|
except: |
|
if model != "gpt-4": |
|
print("Step2 failed with gpt-3.5, trying with gpt-4!") |
|
step2 = FixAnchors(model="gpt-4") |
|
step2_json = step2.fix_passage_anchors(step1_json, text_input) |
|
results["step2_anchored_claims"] = copy.deepcopy(step2_json) |
|
|
|
|
|
print("Step3: Classifying claims") |
|
step3 = ClassifyClaims(model=model) |
|
step3_json = step3.classify_claims(step2_json) |
|
step3_filter = step3.filter_to_objective_claims(step3_json) |
|
results["step3_classify_claims"] = copy.deepcopy(step3_json) |
|
results["step3_objective_claims"] = copy.deepcopy(step3_filter) |
|
|
|
if len(step3_filter) == 0: |
|
return {"fact_checked_md": "No objective claims found!"} |
|
|
|
|
|
print("Step4.1: Gathering evidence") |
|
step4_cohere = CohereEvidence() |
|
step4_json_cohere = ( |
|
step4_cohere.fetch_cohere_semantic_search_results_to_gather_evidence( |
|
step3_filter |
|
) |
|
) |
|
results["step41_cohere_evidence"] = copy.deepcopy(step4_json_cohere) |
|
|
|
|
|
print("Step4.2: Gathering evidence") |
|
step4_json_google = None |
|
if mode == "slow": |
|
step4_json_google = "" |
|
try: |
|
step4_google = GoogleEvidence(model=model) |
|
step4_json_google = step4_google.fetch_search_results_to_gather_evidence( |
|
step3_filter |
|
) |
|
except Exception as e: |
|
print(f"Google search failed: {e}") |
|
pass |
|
results["step42_google_evidence"] = copy.deepcopy(step4_json_google) |
|
|
|
embedding_model = "text-embedding-ada-002" |
|
text_embedding_chunk_size = 500 |
|
|
|
srcs = [step4_json_cohere] |
|
if step4_json_google: |
|
srcs.append(step4_json_google) |
|
|
|
|
|
print("Step5: Embedding evidence") |
|
step5 = EmbedResults( |
|
embedding_model=embedding_model, |
|
text_embedding_chunk_size=text_embedding_chunk_size, |
|
) |
|
faiss_db = step5.embed_for_uuid(srcs) |
|
|
|
|
|
print("Step6: Linking claims to evidence") |
|
step6 = ClaimToEvidence() |
|
step6_json = step6.link_claims_to_evidence(step3_filter, faiss_db) |
|
results["step6_claim_to_evidence"] = copy.deepcopy(step6_json) |
|
|
|
|
|
print("Step7: Checking claims against evidence") |
|
step7 = CheckClaimAgainstEvidence(model=model) |
|
step7_json = step7.check_claims_against_evidence(step6_json) |
|
results["step7_evaluated_claims"] = copy.deepcopy(step7_json) |
|
|
|
|
|
print("Step8: Formatting") |
|
step8 = FormatDocument(model=model, footnote_style="verbose") |
|
step8_md = step8.reformat_document_to_include_claims( |
|
text_input, step7_json, footnote_style="verbose" |
|
) |
|
step8_md_terse = step8.reformat_document_to_include_claims( |
|
text_input, step7_json, footnote_style="terse" |
|
) |
|
|
|
results["fact_checked_md"] = copy.deepcopy(step8_md) |
|
results["fact_checked_terse"] = copy.deepcopy(step8_md_terse) |
|
return results |
|
|
|
|
|
def main(args): |
|
with open(args.file, "r") as f: |
|
text = f.read() |
|
out = get_fact_checked(text, mode="slow", model=args.model) |
|
print(out["fact_checked_md"]) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser(description="Process a file.") |
|
parser.add_argument("--file", type=str, help="File to process", required=True) |
|
parser.add_argument("--model", type=str, help="Model to use", required=True) |
|
args = parser.parse_args() |
|
main(args) |
|
|