BounWiki / run.py
LeoGitGuy
reqs
f8b79d2
raw
history blame
5.75 kB
# imports
import logging
#import torch_scatter
import argparse
import pickle
import csv
import collections
import itertools
import copy
from setup_database import get_document_store, add_data
from setup_modules import create_retriever, create_readers_and_pipeline, text_reader_types, table_reader_types
from eval_helper import create_labels
def parse_args():
    """Parse the command-line options of an evaluation run.

    Returns:
        argparse.Namespace with attributes ``context`` (list of context
        names), ``text_reader``, ``api_key`` (str, or None when not given),
        ``table_reader`` and ``seperate_evaluation`` (bool).
    """
    parser = argparse.ArgumentParser(description="JointXplore")
    parser.add_argument(
        "--context",
        help='which information should be added as context, subset of [processed_website_tables, processed_website_text, processed_schedule_tables], enter as multiple strings',
        nargs='+',
        default=["processed_website_tables", "processed_website_text", "processed_schedule_tables"],
    )
    parser.add_argument(
        "--text_reader",
        help="specify the model to use as text reader",
        choices=["minilm", "distilroberta", "electra-base", "bert-base", "deberta-large", "gpt3"],
        default="bert-base",
    )
    # The API key is a value, not a switch: the former `store_true` flag could
    # only ever produce True/False. nargs="?" keeps a bare `--api-key` valid
    # (it yields None) so older invocations still parse.
    parser.add_argument(
        "--api-key",
        help="API key to use when gpt3 is chosen as text reader",
        nargs="?",
        default=None,
        metavar="KEY",
    )
    parser.add_argument(
        "--table_reader",
        help="choose tapas or convert table to text file and treat them as such",
        choices=["tapas", "text"],
        default="tapas",
    )
    parser.add_argument(
        "--seperate_evaluation",
        help="if specified, student generated questions and synthetically generated questions are evaluated seperately",
        action="store_true",
    )
    return parser.parse_args()
def _resolve_run_config(args):
    """Return ``(filenames, text_reader, table_reader, seperate_evaluation)``.

    *args* is the positional-argument tuple passed to ``main``. When it is
    empty, the values are read from the command line via ``parse_args``.
    Otherwise it must contain exactly those four values in that order.
    """
    # `args` comes from `*args`, so it is always a tuple; test for emptiness
    # (comparing it with `{}` would never be true).
    if not args:
        cli = parse_args()
        return cli.context, cli.text_reader, cli.table_reader, cli.seperate_evaluation
    filenames, text_reader, table_reader, seperate_evaluation = args
    return filenames, text_reader, table_reader, seperate_evaluation


def _read_csv_header(path):
    """Return the first row of the CSV file at *path*.

    Raises:
        ValueError: if the file is empty (no header row to write against).
    """
    with open(path, "r", newline="") as f:
        header = next(csv.reader(f), None)
    if header is None:
        raise ValueError(f"{path} is empty; it must start with a header row")
    return header


def _select_metrics(res_dict, exp_dict):
    """Build one results row from evaluation metrics plus experiment info.

    The row merges the "EmbeddingRetriever" metrics with those of the first
    node found among "JoinAnswers", "TableReader" and "TextReader" (in that
    order of preference), then *exp_dict*. Later keys win on collisions.

    Raises:
        ValueError: if none of the reader nodes is present in *res_dict*.
    """
    for node in ("JoinAnswers", "TableReader", "TextReader"):
        if node in res_dict:
            return {**res_dict['EmbeddingRetriever'], **res_dict[node], **exp_dict}
    # Previously this fell through and either raised NameError or silently
    # reused the row of the preceding label set.
    raise ValueError(f"no reader metrics in evaluation result; nodes present: {sorted(res_dict)}")


def _combine_weighted(csv_dict, csv_dict_new):
    """Merge two results rows into a single "all_eval" row.

    For each key of *csv_dict*:
    - string values are copied from *csv_dict*;
    - "num_examples_for_eval" is summed;
    - every other value is averaged, weighted by each row's
      "num_examples_for_eval".
    "Label type" is then set to "all_eval". Assumes *csv_dict_new* has at
    least the keys of *csv_dict*.
    """
    weight_old = csv_dict["num_examples_for_eval"]
    weight_new = csv_dict_new["num_examples_for_eval"]
    total_num_samples = weight_old + weight_new
    print("Weights for datasets:", weight_old, weight_new)
    csv_dict_all = {}
    for key, val in csv_dict.items():
        if isinstance(val, str):
            csv_dict_all[key] = val
        elif key == "num_examples_for_eval":
            csv_dict_all[key] = val + csv_dict_new[key]
        else:
            csv_dict_all[key] = (val * weight_old + csv_dict_new[key] * weight_new) / total_num_samples
    csv_dict_all["Label type"] = "all_eval"
    return csv_dict_all


def _append_row(path, header, row):
    """Append *row* (a dict) to the CSV file at *path*, using *header* as field order."""
    with open(path, "a", newline='') as f:
        csv.DictWriter(f, fieldnames=header).writerow(row)


def main(*args):
    """Index the selected context, evaluate the QA pipeline and record results.

    With no arguments, the configuration is read from the command line.
    Otherwise pass exactly ``(filenames, text_reader, table_reader,
    seperate_evaluation)``.

    For every label set returned by ``create_labels``, this function:
    - pickles the raw evaluation result to
      ``./output/<text_reader>_<table_reader>_<context>_<label type>``;
    - appends one metrics row to ``./output/results.csv``, whose first row
      must be the header.

    With separate evaluation, an extra "all_eval" row, weighted by example
    count, is written while processing the second label set. The document
    index is deleted when the run ends, including when it fails.
    """
    filenames, text_reader, table_reader, seperate_evaluation = _resolve_run_config(args)
    print(f"Filenames: {filenames}")
    # Schedule tables enable the table path; website text or website tables
    # enable the text path. (`"a" or "b" in x` is always truthy, hence any().)
    use_table = "processed_schedule_tables" in filenames
    use_text = any(name in filenames for name in ("processed_website_text", "processed_website_tables"))
    # configure logger
    logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
    logging.getLogger("haystack").setLevel(logging.WARNING)
    logging.info("Starting..")
    results_csv = "./output/results.csv"
    document_index = "document"
    document_store = get_document_store(document_index)
    try:
        print(f"Number of docs previously: {len(document_store.get_all_documents())}")
        document_store, data = add_data(filenames, document_store, document_index)
        print(f"Number of docs after: {len(document_store.get_all_documents())}")
        document_store, retriever = create_retriever(document_store)
        text_reader_type = text_reader_types[text_reader]
        table_reader_type = table_reader_types[table_reader]
        pipeline = create_readers_and_pipeline(retriever, text_reader_type, table_reader_type, use_table, use_text)
        header = _read_csv_header(results_csv)
        labels_file = "./data/validation_data/processed_qa.json"
        labels = create_labels(labels_file, data, seperate_evaluation)
        label_types = ["students", "synthetic"] if seperate_evaluation else ["all_eval"]
        context = ('_').join(filenames)
        csv_dict = None
        for idx, label in enumerate(labels):
            print(f"Label Dataset: {idx}")
            results = pipeline.eval(label, params={"top_k": 10}, sas_model_name_or_path="cross-encoder/stsb-roberta-large")
            res_dict = results.calculate_metrics()
            print(res_dict)
            with open(f"./output/{text_reader}_{table_reader}_{context}_{label_types[idx]}", "wb") as fp:
                pickle.dump(results, fp)
            exp_dict = {
                "Text Reader": text_reader,
                "Table Reader": table_reader,
                "Context": context,
                "Label type": label_types[idx],
            }
            csv_dict_new = _select_metrics(res_dict, exp_dict)
            if idx == 1:
                # Second label set: also emit the combined row over both sets.
                _append_row(results_csv, header, _combine_weighted(csv_dict, csv_dict_new))
            csv_dict = csv_dict_new
            print(csv_dict)
            _append_row(results_csv, header, csv_dict)
    finally:
        # Always drop the index so the next run starts from an empty store.
        document_store.delete_index(document_index)
if __name__ == "__main__":
    # Parse the CLI here and hand main() the four values in the order it
    # unpacks them, rather than relying on main()'s no-argument path.
    cli_args = parse_args()
    main(cli_args.context, cli_args.text_reader, cli_args.table_reader, cli_args.seperate_evaluation)