from email.utils import parseaddr | |
from huggingface_hub import HfApi | |
import os | |
import datetime | |
import pandas as pd | |
import json | |
import evaluate as nlp_evaluate | |
import re | |
import sqlite3 | |
import random | |
from tqdm import tqdm | |
import sys | |
import numpy as np | |
from get_exact_and_f1_score.ext_services.jsql_parser import JSQLParser | |
from get_exact_and_f1_score.metrics.partial_match_eval.evaluate import evaluate | |
# Fixed seed for the `random` module so repeated runs behave the same.
random.seed(10001)

# Text-overlap metrics, loaded once at import time through the `evaluate`
# package (imported above as `nlp_evaluate`).
bleu = nlp_evaluate.load("bleu")
rouge = nlp_evaluate.load("rouge")

# Hugging Face Hub repository ids and the API client used for uploads.
LEADERBOARD_PATH = "Exploration-Lab/BookSQL-Leaderboard"
RESULTS_PATH = "Exploration-Lab/BookSQL-Leaderboard-results"
api = HfApi()
# Hub access token taken from the environment; None when TOKEN is unset.
TOKEN = os.environ.get("TOKEN")
YEAR_VERSION = "2024"

# SQLite database that gold and predicted queries are executed against.
sqlite_path = "accounting/accounting_for_testing.sqlite"
# Parser used to translate SQL before exact/partial-match scoring.
_jsql_parser = JSQLParser.create()
def format_error(msg):
    """Return *msg* wrapped in a red, centred, 20px HTML paragraph."""
    css = "color: red; font-size: 20px; text-align: center;"
    return f"<p style='{css}'>{msg}</p>"
def format_warning(msg):
    """Return *msg* wrapped in an orange, centred, 20px HTML paragraph."""
    css = "color: orange; font-size: 20px; text-align: center;"
    return f"<p style='{css}'>{msg}</p>"
def format_log(msg):
    """Return *msg* wrapped in a green, centred, 20px HTML paragraph."""
    css = "color: green; font-size: 20px; text-align: center;"
    return f"<p style='{css}'>{msg}</p>"
def model_hyperlink(link, model_name):
    """Return an HTML anchor showing *model_name* that opens *link* in a new tab."""
    anchor_css = (
        "color: var(--link-text-color); "
        "text-decoration: underline;text-decoration-style: dotted;"
    )
    return f'<a target="_blank" href="{link}" style="{anchor_css}">{model_name}</a>'
def input_verification(method_name, url, path_to_file, organisation, mail):
    """Validate the submission form fields.

    Returns the bare e-mail address parsed from *mail* when every check
    passes. Otherwise returns an HTML warning string built by
    ``format_warning``; the caller has to tell the two apart.
    """
    # Text fields must be present. None is treated like an empty string so a
    # missing value produces a warning instead of failing inside parseaddr().
    for text_field in (method_name, url, organisation, mail):
        if text_field is None or text_field == "":
            return format_warning("Please fill all the fields.")
    if path_to_file == "":
        return format_warning("Please fill all the fields.")

    # Very basic email parsing: only require an "@" in the parsed address.
    _, parsed_mail = parseaddr(mail)
    if "@" not in parsed_mail:
        return format_warning("Please provide a valid email address.")

    # A missing upload is reported separately from the empty-field case.
    if path_to_file is None:
        return format_warning("Please attach a file.")
    return parsed_mail
def replace_current_date_and_now(_sql, _date):
    """Substitute the quoted literal *_date* for ``current_date`` and ``, now``.

    Every ``current_date`` becomes ``'<_date>'`` and every ``, now`` becomes
    ``, '<_date>'``; the rewritten SQL text is returned.
    """
    quoted_date = "'" + _date + "'"
    with_date = _sql.replace('current_date', quoted_date)
    return with_date.replace(', now', ", " + quoted_date)
def remove_gold_Non_exec(data, df1, sqlite_path):
    """Find which queries in *data* execute against the SQLite database.

    Each query is normalised (double quotes -> single quotes, lower-cased,
    ``current_date``/``now`` pinned to 2022-06-01, ``%y`` -> ``%Y``) and run.

    Returns ``(out, non_exec, new_df)``: the positions of the queries that
    executed, the positions of those that raised, and a copy of *df1* with an
    ``'Exec/Non-Exec'`` column set to 1 for executed rows and 0 otherwise.
    Positions are used as ``.loc`` labels, so *df1* is assumed to carry a
    default 0..n-1 index (TODO confirm against callers).
    """
    out, non_exec = [], []
    new_df = df1.copy()
    new_df.loc[:, 'Exec/Non-Exec'] = 0

    con = sqlite3.connect(sqlite_path)
    try:
        cur = con.cursor()
        for i, s in tqdm(enumerate(data)):
            _sql = str(s).replace('"', "'").lower()
            _sql = replace_current_date_and_now(_sql, '2022-06-01')
            _sql = replace_percent_symbol_y(_sql)
            try:
                cur.execute(_sql)
                cur.fetchall()
            except Exception:
                # Any failure marks the query non-executable; `Exception`
                # rather than a bare except so Ctrl-C / SystemExit still
                # propagate.
                non_exec.append(i)
                print("_sql: ", _sql)
            else:
                out.append(i)
    finally:
        # Release the connection even if something unexpected escapes.
        con.close()

    new_df.loc[out, 'Exec/Non-Exec'] = 1
    return out, non_exec, new_df
def remove_data_from_index(data, ind_list):
    """Return a list of ``data[i]`` for each index ``i`` in *ind_list*, in order."""
    return [data[position] for position in ind_list]
def get_exec_match_acc(gold, pred):
    """Compute exact-match accuracy and partial-component-match F1.

    Both query lists are normalised (single quotes -> double quotes,
    lower-cased, runs of spaces collapsed), translated with the module-level
    JSQL parser and scored pairwise. Pairs whose F1 or exact-match score is
    not a float are excluded from the averages; string-valued F1 entries are
    counted and reported as JSQL errors.

    Returns ``(exact_match_accuracy, partial_f1)``; both are 0.0 when no pair
    could be scored.
    """
    # The module-level name `evaluate` is rebound later in this file by
    # `def evaluate(df)`, so the partial-match scorer is imported here under
    # a private alias to be sure the right function is called.
    from get_exact_and_f1_score.metrics.partial_match_eval.evaluate import (
        evaluate as _pcm_evaluate,
    )

    assert len(gold) == len(pred)

    def _normalise(sql):
        return re.sub(' +', ' ', str(sql).replace("'", '"').lower())

    goldd = _jsql_parser.translate_batch([_normalise(g) for g in gold])
    predd = _jsql_parser.translate_batch([_normalise(p) for p in pred])

    # Materialise the results because they are traversed more than once.
    pcm_f1_scores = list(_pcm_evaluate(goldd, predd))
    pcm_em_scores = list(_pcm_evaluate(goldd, predd, exact_match=True))

    scored_pairs = [
        (f1, em)
        for f1, em in zip(pcm_f1_scores, pcm_em_scores)
        if isinstance(f1, float) and isinstance(em, float)
    ]

    jsql_error_count = sum(1 for score in pcm_f1_scores if isinstance(score, str))
    print("JSQLError in sql: ", jsql_error_count)

    if not scored_pairs:
        # Nothing could be scored; avoid dividing by zero.
        return 0.0, 0.0

    total = len(scored_pairs)
    f1_sum = sum(f1 for f1, _ in scored_pairs)
    em_sum = sum(em for _, em in scored_pairs)
    return em_sum / total, f1_sum / total
def replace_percent_symbol_y(_sql):
    """Return *_sql* with every ``%y`` replaced by ``%Y``."""
    return _sql.replace('%y', "%Y")
def get_exec_results(sqlite_path, scores, df, flag, gold_sql_map_res=None):
    """Execute each query in *scores* and record the stringified results.

    Queries are normalised the same way as in ``remove_gold_Non_exec``
    (quotes, lower-casing, pinned date, ``%y`` -> ``%Y``) before execution.

    Parameters:
        sqlite_path: path of the SQLite database to query.
        scores: iterable of SQL strings.
        df: DataFrame whose copy receives the results; positions in *scores*
            are used as ``.loc`` labels, so a 0..n-1 index is assumed
            (TODO confirm against callers).
        flag: ``'g'``, ``'p'`` or ``'d'`` selects the ``GOLD_res``,
            ``PRED_res`` or ``DEBUG_res`` column; other values write nothing.
        gold_sql_map_res: unused; kept for backward compatibility.

    Returns ``(out, non_exec, new_df)``: position -> ``str(rows)`` for queries
    that ran, position -> exception for those that failed, and the updated
    copy of *df*.
    """
    # NOTE(review): queries are executed verbatim; when they come from user
    # submissions, consider opening the database read-only.
    result_columns = {'g': 'GOLD_res', 'p': 'PRED_res', 'd': 'DEBUG_res'}
    out, non_exec = {}, {}
    new_df = df.copy()

    con = sqlite3.connect(sqlite_path)
    try:
        cur = con.cursor()
        for i, s in enumerate(tqdm(scores)):
            _sql = str(s).replace('"', "'").lower()
            _sql = replace_current_date_and_now(_sql, '2022-06-01')
            _sql = replace_percent_symbol_y(_sql)
            try:
                cur.execute(_sql)
                out[i] = str(cur.fetchall())
            except Exception as err:
                # Keep the error so callers can inspect why a query failed.
                non_exec[i] = err
    finally:
        con.close()

    column = result_columns.get(flag)
    if column is not None:
        new_df.loc[list(out.keys()), column] = list(out.values())
    return out, non_exec, new_df
def get_scores(gold_dict, pred_dict):
    """Compare predicted execution results with gold results key by key.

    Keys of *pred_dict* missing from *gold_dict* are ignored. A gold value
    that equals or contains ``str(None)`` is counted as "none" and skipped.
    The remaining keys are split into matches and mismatches.

    Returns ``(match_count, mismatch_count, none_count, matched_keys,
    mismatched_keys)``.
    """
    none_marker = str(None)
    matched, mismatched = [], []
    skipped_none = 0
    for key in pred_dict:
        if key not in gold_dict:
            continue
        expected = gold_dict[key]
        if expected == none_marker or none_marker in expected:
            skipped_none += 1
        elif pred_dict[key] == expected:
            matched.append(key)
        else:
            mismatched.append(key)
    return len(matched), len(mismatched), skipped_none, matched, mismatched
def get_total_gold_none_count(gold_dict):
    """Count usable versus "none" gold results.

    A value counts as "none" when it equals or contains ``str(None)``.
    Returns ``(ok_count, none_count)``.
    """
    none_marker = str(None)
    none_total = sum(
        1
        for res in gold_dict.values()
        if res == none_marker or none_marker in res
    )
    return len(gold_dict) - none_total, none_total
def evaluate(df):
    """Score a submission against the gold SQL of the test split.

    *df* must provide an ``'id'`` column (used to index ``tests/test.json``
    and ``tests/test_sql.json``) and a ``'pred_sql'`` column with the
    predicted query for that id.

    Gold queries that do not execute against the SQLite database are dropped,
    together with their predictions, before any metric is computed.

    Returns ``(exact_match, execution_accuracy, partial_f1, bleu, rougeL)``.
    """
    pred_sql = df['pred_sql'].to_list()
    ids = df['id'].to_list()

    # Context managers so the JSON files are closed after loading.
    with open("tests/test.json") as f:
        questions_and_ids = json.load(f)
    with open("tests/test_sql.json") as ts:
        gold_sql = json.load(ts)

    gold_sql_list, pred_sql_list, questions_list = [], [], []
    for idx, pred in zip(ids, pred_sql):
        questions_list.append(questions_and_ids[idx]['Query'])
        gold_sql_list.append(gold_sql[idx]['SQL'])
        # Store the prediction itself (the list used to be appended to
        # itself); a missing prediction becomes an empty query string.
        pred_sql_list.append("" if pd.isna(pred) else str(pred))

    df = pd.DataFrame({
        'NLQ': questions_list,
        'GOLD SQL': gold_sql_list,
        'PREDICTED SQL': pred_sql_list,
    })
    test_size = len(df)

    pred_score = df['PREDICTED SQL'].str.lower().values
    gold_score1 = df['GOLD SQL'].str.lower().values

    print("Checking non-exec Gold sql query")
    gold_exec, gold_not_exec, new_df = remove_gold_Non_exec(gold_score1, df, sqlite_path)
    print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
    print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))

    # Keep only rows whose gold query executed. reset_index() returns a fresh
    # frame (old labels kept in an 'index' column) so that positions can be
    # used as labels by get_exec_results.
    new_df = new_df[new_df['Exec/Non-Exec'] == 1].reset_index()

    # Removing Non-exec sql from data
    print(f"Removing {len(gold_not_exec)} non-exec sql query from all Gold/Pred/Debug")
    gold_score1 = remove_data_from_index(gold_score1, gold_exec)
    pred_score = remove_data_from_index(pred_score, gold_exec)
    # The metrics take one list of references per prediction.
    gold_score = [[x] for x in gold_score1]
    assert len(gold_score) == len(pred_score)

    pred_bleu_score = bleu.compute(predictions=pred_score, references=gold_score)
    pred_rouge_score = rouge.compute(predictions=pred_score, references=gold_score)
    pred_exact_match, pred_partial_f1_score = get_exec_match_acc(gold_score1, pred_score)

    print("PREDICTED_vs_GOLD Final bleu_score: ", pred_bleu_score['bleu'])
    print("PREDICTED_vs_GOLD Final rouge_score: ", pred_rouge_score['rougeL'])
    print("PREDICTED_vs_GOLD Exact Match Accuracy: ", pred_exact_match)
    print("PREDICTED_vs_GOLD Partial CM F1 score: ", pred_partial_f1_score)
    print()

    # Rows whose query fails to execute keep the 'None' placeholder.
    new_df.loc[:, 'GOLD_res'] = str(None)
    new_df.loc[:, 'PRED_res'] = str(None)

    print("Getting Gold results")
    gout_res_dict, gnon_exec_err_dict, new_df = get_exec_results(sqlite_path, gold_score1, new_df, 'g')
    total_gold_ok_count, total_gold_none_count = get_total_gold_none_count(gout_res_dict)
    print("Total Gold None count: ", total_gold_none_count)

    print("Getting Pred results")
    pout_res_dict, pnon_exec_err_dict, new_df = get_exec_results(sqlite_path, pred_score, new_df, 'p')

    print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
    print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))
    print()
    print("PRED Total exec SQL query: {}/{}".format(len(pout_res_dict), len(pred_score)))
    print("PRED Total non-exec SQL query: {}/{}".format(len(pnon_exec_err_dict), len(pred_score)))
    print()

    (
        pred_correct_exec_acc_count,
        pred_incorrect_exec_acc_count,
        pred_none_count,
        pred_correct_sql,
        pred_incorrect_sql,
    ) = get_scores(gout_res_dict, pout_res_dict)

    # Guard the ratio: with no usable gold result the accuracy is 0.
    if total_gold_ok_count:
        exec_accuracy = pred_correct_exec_acc_count / total_gold_ok_count
    else:
        exec_accuracy = 0.0

    print("PRED_vs_GOLD Correct_Exec_count without None: ", pred_correct_exec_acc_count)
    print("PRED_vs_GOLD Incorrect_Exec_count without None: ", pred_incorrect_exec_acc_count)
    print("PRED_vs_GOLD Exec_Accuracy: ", exec_accuracy)
    print()

    return (
        pred_exact_match,
        exec_accuracy,
        pred_partial_f1_score,
        pred_bleu_score['bleu'],
        pred_rouge_score['rougeL'],
    )
def add_new_eval(
    method_name: str,
    url: str,
    path_to_file: str,
    organisation: str,
    mail: str,
):
    """Validate, score and publish a leaderboard submission.

    Steps: validate the form fields, read the submitted CSV, evaluate it,
    upload the raw submission (with metadata) to the results dataset, then
    append the scored rows to the leaderboard CSV and upload that to the
    leaderboard space.

    Returns an HTML message: the validation warning when the inputs are
    rejected, otherwise a success notice.
    """
    import io

    parsed_mail = input_verification(
        method_name,
        url,
        path_to_file,
        organisation,
        mail,
    )
    # input_verification returns an HTML warning paragraph instead of an
    # address when a check fails; hand it back rather than processing a
    # rejected submission.
    if parsed_mail.startswith("<p "):
        return parsed_mail

    # One timestamp for both the stored metadata and the file name.
    submitted_at = datetime.datetime.now()

    # load the file
    df = pd.read_csv(path_to_file)
    # Keep an untouched copy for scoring instead of re-reading the file.
    submission_df = df.copy()

    # modify the df to include metadata
    df["Method"] = method_name
    df["url"] = url
    df["organisation"] = organisation
    df["mail"] = parsed_mail
    df["timestamp"] = submitted_at

    submission_df["Method"] = method_name
    submission_df["Submitted By"] = organisation

    # upload to spaces using the hf api
    path_in_repo = f"submissions/{method_name}"
    date_stamp = submitted_at.strftime('%Y-%m-%d')
    file_name = f"{method_name}-{organisation}-{date_stamp}.csv"

    EM, EX, PCM_F1, BLEU, ROUGE = evaluate(submission_df)
    submission_df['EM'] = EM
    submission_df['EX'] = EX
    submission_df['BLEU'] = BLEU
    submission_df['ROUGE'] = ROUGE

    # upload the df to spaces
    buffer = io.BytesIO()
    df.to_csv(buffer, index=False)  # Write the DataFrame to a buffer in CSV format
    buffer.seek(0)  # Rewind the buffer to the beginning
    api.upload_file(
        repo_id=RESULTS_PATH,
        path_in_repo=f"{path_in_repo}/{file_name}",
        path_or_fileobj=buffer,
        token=TOKEN,
        repo_type="dataset",
    )

    # read the leaderboard
    leaderboard_df = pd.read_csv("submissions/baseline/baseline.csv")

    # append the new submission_df csv to the leaderboard
    # NOTE(review): this appends every submission row (including its id and
    # pred_sql columns), not one summary row per method -- confirm intended.
    leaderboard_df = pd.concat([leaderboard_df, submission_df], ignore_index=True)

    # save the new leaderboard
    leaderboard_buffer = io.BytesIO()
    leaderboard_df.to_csv(leaderboard_buffer, index=False)
    leaderboard_buffer.seek(0)  # Rewind before handing the buffer to the API
    api.upload_file(
        repo_id=LEADERBOARD_PATH,
        path_in_repo="submissions/baseline/baseline.csv",
        path_or_fileobj=leaderboard_buffer,
        token=TOKEN,
        repo_type="space",
    )

    return format_log(
        f"Method {method_name} submitted by {organisation} successfully. \nPlease refresh the leaderboard, and wait a bit to see the score displayed"
    )