# Dataset construction utilities: load source corpora and build paired
# (text_a, text_b, text_c, label) examples for factual-consistency training.
from logging import error | |
from datasets import load_dataset | |
import transformers | |
from random import sample | |
import random | |
import torch | |
import json | |
from tqdm import tqdm | |
from nltk.translate.bleu_score import sentence_bleu | |
import pandas as pd | |
import re | |
'''
Output data format (one dict per example):
{text_a, text_b, text_c, label: None or 0/1}
'''
# Maps a dataset nickname to the positional arguments for datasets.load_dataset,
# with the FINAL list element being the split to select ('train', 'validation', ...).
# NOTE(review): some entries (e.g. 'drop', 'quoref', 'boolq') carry only
# [path, split] with no config name — the loader strips the last element as
# the split and passes the rest through.
DATASET_HUGGINGFACE = {
    'cnndm': ['cnn_dailymail', '3.0.0', 'train'],
    'mnli': ['multi_nli', 'default', 'train'],
    'squad': ['squad', 'plain_text', 'train'],
    'squad_v2': ['squad_v2', 'squad_v2', 'train'],
    'paws': ['paws', 'labeled_final', 'train'],
    'vitaminc': ['tals/vitaminc', 'v1.0', 'train'],
    'xsum': ['xsum', 'default', 'train'],
    'stsb': ['glue', 'stsb', 'train'],
    'sick': ['sick', 'default', 'train'],
    'race': ['race', 'all', 'train'],
    'race_val': ['race', 'all', 'validation'],
    'anli_r1': ['anli', 'plain_text', 'train_r1'],
    'anli_r2': ['anli', 'plain_text', 'train_r2'],
    'anli_r3': ['anli', 'plain_text', 'train_r3'],
    'snli': ['snli', 'plain_text', 'train'],
    'wikihow': ['wikihow', 'all', 'train'],
    'mrpc': ['glue', 'mrpc', 'train'],
    'msmarco': ['ms_marco', 'v2.1', 'train'],
    'mrpc_val': ['glue', 'mrpc', 'validation'],
    'paws_val': ['paws', 'labeled_final', 'validation'],
    'paws_unlabeled': ['paws', 'unlabeled_final', 'train'],
    'qqp': ['glue', 'qqp', 'train'],
    'qqp_val': ['glue', 'qqp', 'validation'],
    'squad_v2_new': ['squad_v2', 'squad_v2', 'train'],
    'adversarial_qa': ['adversarial_qa', 'adversarialQA', 'train'],
    'drop': ['drop', 'train'],
    'duorc_self': ['duorc', 'SelfRC', 'train'],
    'duorc_paraphrase': ['duorc', 'ParaphraseRC', 'train'],
    'quoref': ['quoref', 'train'],
    'hotpot_qa_distractor': ['hotpot_qa', 'distractor', 'train'],
    'hotpot_qa_fullwiki': ['hotpot_qa', 'fullwiki', 'train'],
    'ropes': ['ropes', 'train'],
    'boolq': ['boolq', 'train'],
    'eraser_multi_rc': ['eraser_multi_rc', 'train'],
    'quail': ['quail', 'train'],
    'sciq': ['sciq', 'train'],
    'strategy_qa': ['metaeval/strategy-qa', 'train'],
    'gap': ['gap', 'train'],
}
# Per-dataset processing configuration consumed by DataGenerator:
#   task            - task-family tag (summarization, nli, qa, paraphrase, ...)
#   text_a / text_b - field name(s) holding the two text sides; text_b may be a
#                     list of fields (e.g. ['question', 'answers'])
#   label           - label field name, or None for unlabeled data
#   huggingface / using_hf_api / using_pandas / using_json
#                   - which branch of load_dataset_from_huggingface is used
#   data_path / data_dir - location of local (non-hub) data files
DATASET_CONFIG = {
    'cnndm': {'task': 'summarization', 'text_a': 'article', 'text_b': 'highlights', 'label': None, 'huggingface': True},
    'mnli': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    'nli_fever': {'task': 'fact_checking', 'text_a': 'context', 'text_b': 'query', 'label': 'label','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/nli_fever/train_fitems.jsonl' },
    'doc_nli': {'task': 'bin_nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/DocNLI_dataset/train.json' },
    'squad': {'task': 'extractive_qa', 'text_a': 'context', 'text_b': ['question', 'answers'], 'label': None, 'huggingface': True},
    'squad_v2': {'task': 'qa', 'text_a': 'context', 'text_b': ['question', 'answers'], 'label': None, 'huggingface': True},
    'paws': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    'vitaminc': {'task': 'fact_checking', 'text_a': 'evidence', 'text_b': 'claim', 'label': 'label', 'huggingface': True},
    'xsum': {'task': 'summarization', 'text_a': 'document', 'text_b': 'summary', 'label': None, 'huggingface': True, 'cliff_path': 'data/model_generated_data/cliff_summ/xsum_train.jsonl'},
    'stsb': {'task': 'sts', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    'sick': {'task': 'sts', 'text_a': 'sentence_A', 'text_b': 'sentence_B', 'label': 'relatedness_score', 'huggingface': True},
    'race': {'task': 'qa', 'text_a': 'article', 'text_b': ['question', 'options'], 'label': 'answer', 'huggingface': True},
    'race_val': {'task': 'qa', 'text_a': 'article', 'text_b': ['question', 'options'], 'label': 'answer', 'huggingface': True},
    'anli_r1': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    'anli_r2': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    'anli_r3': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    'snli': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    'wikihow': {'task': 'summarization', 'text_a': 'text', 'text_b': 'headline', 'label': None, 'huggingface': False, 'using_hf_api': True, 'data_dir': 'data/wikihow_raw'},
    'mrpc': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label','huggingface': True},
    'mrpc_val': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label','huggingface': True},
    'paws_val': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    'paws_unlabeled': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    'msmarco': {'task': 'ir', 'text_a': 'passages', 'text_b': ['query', 'answers'], 'label': None,'huggingface': True},
    'paws_qqp': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': None,'huggingface': False, 'using_hf_api': False, 'using_pandas': True, 'data_path':'paws_qqp/output/train.tsv' },
    'wiki103': {'task': 'paraphrase', 'text_a': 'original_sent', 'text_b': 'paraphrase', 'label': None,'huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json': True, 'data_path':'data/model_generated_data/backtranslation/wiki103_single_sent_backtranslation.json'},
    'qqp': {'task': 'paraphrase', 'text_a':'question1', 'text_b':'question2', 'label': 'label', 'huggingface': True},
    'qqp_val': {'task': 'paraphrase', 'text_a':'question1', 'text_b':'question2', 'label': 'label', 'huggingface': True},
    'wmt17xxx': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': True, 'data_path':'data/wmt/wmt17/2017-da.csv' },
    'wmt15': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt15_eval.jsonl' },
    'wmt16': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt16_eval.jsonl' },
    'wmt17': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt17_eval.jsonl' },
    'wmt18': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt18_eval.jsonl' },
    'wmt19': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt19_eval.jsonl' },
    'squad_v2_new': {'task': 'qa', 'huggingface': True},
    'adversarial_qa': {'task': 'qa', 'huggingface': True},
    'drop': {'task': 'qa', 'huggingface': True},
    'duorc_self': {'task': 'qa', 'huggingface': True},
    'duorc_paraphrase': {'task': 'qa', 'huggingface': True},
    'quoref': {'task': 'qa', 'huggingface': True},
    'hotpot_qa_distractor': {'task': 'qa', 'huggingface': True},
    'hotpot_qa_fullwiki': {'task': 'qa', 'huggingface': True},
    'newsqa': {'task': 'qa', 'using_json': True, 'raw_json': True, 'data_path': 'data/newsqa_raw/combined-newsqa-data-v1.json'},
    'ropes': {'task': 'qa', 'huggingface': True},
    'boolq': {'task': 'qa', 'huggingface': True},
    'eraser_multi_rc': {'task': 'qa', 'huggingface': True},
    'quail': {'task': 'qa', 'huggingface': True},
    'sciq': {'task': 'qa', 'huggingface': True},
    'strategy_qa': {'task': 'qa', 'huggingface': True},
    'gap': {'task': 'coreference', 'huggingface': True},
}
class QA2D():
    """Convert (question, answer) pairs into declarative sentences using the
    MarkS/bart-base-qa2d BART model."""

    def __init__(self, batch_size=32, device='cuda', verbose=True) -> None:
        from transformers import BartTokenizer, BartForConditionalGeneration
        self.tokenizer = BartTokenizer.from_pretrained("MarkS/bart-base-qa2d")
        self.model = BartForConditionalGeneration.from_pretrained("MarkS/bart-base-qa2d").to(device)
        self.batch_size = batch_size
        self.device = device
        self.verbose = verbose

    def generate(self, questions: list, answers: list):
        """Return one declarative sentence per (question, answer) pair.

        questions and answers must have equal length; inputs are formatted as
        "question: ... answer: ..." and decoded with special tokens stripped.
        """
        assert len(questions) == len(answers)
        qa_list = [f"question: {q} answer: {a}" for q, a in zip(questions, answers)]
        output = []
        # Ceil division so tqdm's total also counts the final partial batch
        # (the old int(len/bs) undercounted whenever len % bs != 0).
        total_batches = -(-len(qa_list) // self.batch_size)
        for qa_pairs in tqdm(
            self.chunks(qa_list, self.batch_size),
            desc="QA to Declarative",
            total=total_batches,
            disable=(not self.verbose)
        ):
            input_token = self.tokenizer(
                qa_pairs, return_tensors='pt', padding=True, truncation=True).to(self.device)
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                dec_sents = self.model.generate(
                    input_token.input_ids, max_length=512)
            output.extend(self.tokenizer.batch_decode(
                dec_sents, skip_special_tokens=True))
        return output

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]
class QAnswering():
    """
    To answer not-answerable questions: generates a plausible answer for each
    question given its context, using valhalla/t5-base-qa-qg-hl.
    """

    def __init__(self, batch_size=32, device='cuda') -> None:
        from transformers import T5Tokenizer, T5ForConditionalGeneration
        self.tokenizer = T5Tokenizer.from_pretrained(
            "valhalla/t5-base-qa-qg-hl")
        self.model = T5ForConditionalGeneration.from_pretrained(
            "valhalla/t5-base-qa-qg-hl").to(device)
        self.batch_size = batch_size
        self.device = device

    def generate(self, questions: list, contexts: list):
        """Return one generated answer string per (question, context) pair."""
        assert len(questions) == len(contexts)
        answers = []
        # Ceil division so tqdm's total also counts the final partial batch
        # (the old int(len/bs) undercounted whenever len % bs != 0).
        total_batches = -(-len(questions) // self.batch_size)
        for qs, cs in tqdm(
            zip(self.chunks(questions, self.batch_size), self.chunks(contexts, self.batch_size)),
            desc="Generating Answers for not answerable",
            total=total_batches,
        ):
            assert len(qs) == len(cs)
            qc_pairs = [f"""question: {one_q} context: {one_c}""" for one_q, one_c in zip(qs, cs)]
            input_ids = self.tokenizer(
                qc_pairs, padding=True, truncation=True, return_tensors='pt').to(self.device).input_ids
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                outputs = self.model.generate(input_ids, max_length=512)
            answers.extend(self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True))
        return answers

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]
class MLMGeneratorWithPairedData():
    """Corrupt ('hallucinate') sentences by masking a fraction of their tokens
    and refilling the masks with DistilBERT's top MLM predictions."""

    def __init__(self, corpra: list, device='cuda', batch_size=8, mask_percent=0.25) -> None:
        # corpra: list of sentences to be noised.
        # mask_percent: fraction of each sentence's tokens replaced by [MASK].
        self.device = device
        self.tokenizer = transformers.DistilBertTokenizer.from_pretrained(
            "distilbert-base-uncased")
        self.model = transformers.DistilBertForMaskedLM.from_pretrained(
            "distilbert-base-uncased").to(self.device)
        self.mask_percent = mask_percent
        self.batch_size = batch_size
        self.dataset = corpra  # text needs to be noised

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self):
        """Return the MLM-infilled (noised) version of every sentence in self.dataset,
        in the original order."""
        sents_output = []
        for examples in tqdm(self.chunks(self.dataset, self.batch_size), total=int(len(self.dataset)/self.batch_size), desc="MLM Generating"):
            sents_to_be_noised = [each for each in examples]
            sents_noised = self.mlm_infiller(sents_to_be_noised)
            sents_output.extend(sents_noised)
        return sents_output

    def mlm_infiller(self, batch):
        """
        input a batch of sentences, list

        Masks ~mask_percent of each sentence's tokens, runs the MLM head once
        over the whole batch, then splices the predicted token ids back into
        the masked positions and decodes the result.
        """
        masked_batch = []
        masked_batch_ids = []
        for each_sent in batch:
            sent_tokens = self.tokenizer.tokenize(each_sent)
            sent_token_ids = self.tokenizer(each_sent)['input_ids']
            # Randomly pick which token positions (of the plain token list,
            # no special tokens) get masked. int() truncation means very
            # short sentences may get zero masks.
            mask_list = sample(list(range(len(sent_tokens))), int(
                self.mask_percent * len(sent_tokens)))
            sent_tokens = [
                each if i not in mask_list else self.tokenizer.mask_token for i, each in enumerate(sent_tokens)]
            # sent_token_ids carries one leading special token (presumably
            # [CLS] for DistilBERT — hence the i-1 offset aligning ids with
            # sent_tokens positions).
            masked_batch_ids.append(
                [each if i-1 not in mask_list else self.tokenizer.mask_token_id for i, each in enumerate(sent_token_ids)])
            masked_batch.append(' '.join(sent_tokens))
        inputs = self.tokenizer(
            masked_batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        infill_tokens = []
        for i in range(len(masked_batch)):
            # Indices of [MASK] positions within the i-th padded sequence.
            mask_token_index = (inputs.input_ids == self.tokenizer.mask_token_id)[
                i].nonzero(as_tuple=True)[0]
            predicted_token_id = logits[i, mask_token_index].argmax(axis=-1)
            infill_tokens.append(predicted_token_id)
        infilled_sent = []
        for masked_sent_ids, infill_token in zip(masked_batch_ids, infill_tokens):
            # Replace [MASK] ids left-to-right with the model's predictions;
            # order matters, so each prediction fills the first remaining mask.
            for infill_one_token in infill_token:
                for i, each_id in enumerate(masked_sent_ids):
                    if each_id == self.tokenizer.mask_token_id:
                        masked_sent_ids[i] = infill_one_token
                        break
            infilled_sent.append(self.tokenizer.decode(
                masked_sent_ids, skip_special_tokens=True))
        return infilled_sent
class ExtractiveSummarizationGenerator():
    """Extractive summaries via summa's TextRank, using the smallest ratio
    (in steps of 1/20) that yields a non-empty summary."""

    def __init__(self) -> None:
        pass

    def generate(self, texts):
        '''
        texts: list of string
        '''
        from summa.summarizer import summarize
        summaries = []
        for document in tqdm(texts, desc="Extracting Summary"):
            summary = ''
            # Try increasing ratios until TextRank returns something.
            for step in range(1, 20):
                summary = summarize(document, ratio=step / 20.)
                if len(summary) > 0:
                    break
            summaries.append(summary)
        return summaries
class DataGenerator(): | |
def __init__(self, dataset_names) -> None: | |
self.dataset_names = dataset_names | |
self.datasets = dict() | |
self.t5_qa = None | |
self.t5_tokenizer = None | |
self.load_dataset_from_huggingface() | |
def load_dataset_from_huggingface(self): | |
for each_dataset in self.dataset_names: | |
if DATASET_CONFIG[each_dataset].get('huggingface'): | |
self.datasets[each_dataset] = load_dataset( | |
*DATASET_HUGGINGFACE[each_dataset][:-1])[DATASET_HUGGINGFACE[each_dataset][-1]] | |
elif DATASET_CONFIG[each_dataset].get('using_hf_api'): | |
self.datasets[each_dataset] = load_dataset( | |
*DATASET_HUGGINGFACE[each_dataset][:-1], data_dir=DATASET_CONFIG[each_dataset]['data_dir'])[DATASET_HUGGINGFACE[each_dataset][-1]] | |
elif DATASET_CONFIG[each_dataset].get('using_pandas'): | |
if DATASET_CONFIG[each_dataset]['data_path'].split('.')[-1] == 'tsv': | |
self.datasets[each_dataset] = pd.read_csv( | |
DATASET_CONFIG[each_dataset]['data_path'], sep='\t') | |
elif DATASET_CONFIG[each_dataset]['data_path'].split('.')[-1] == 'csv': | |
self.datasets[each_dataset] = pd.read_csv( | |
DATASET_CONFIG[each_dataset]['data_path']) | |
elif DATASET_CONFIG[each_dataset].get('using_json'): | |
self.datasets[each_dataset] = [] | |
if DATASET_CONFIG[each_dataset].get('raw_json'): | |
with open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8') as f: | |
self.datasets[each_dataset] = json.load(f) | |
else: | |
try: | |
json_file = json.load( | |
open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8')) | |
for example in json_file: | |
self.datasets[each_dataset].append(example) | |
except: | |
with open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8') as f: | |
for example in f: | |
self.datasets[each_dataset].append( | |
json.loads(example)) | |
else: | |
error('unable to locate raw dataset...') | |
def process_squad(self): | |
from rake_nltk import Rake | |
r = Rake() | |
topk = 5 | |
threshold = 0.6 | |
output = [] | |
label = -1 | |
for example in tqdm(self.datasets['squad'], desc=f'Constructing squad'): | |
text_a = example[DATASET_CONFIG['squad']['text_a']] | |
question = example[DATASET_CONFIG['squad']['text_b'][0]] | |
answer = example[DATASET_CONFIG['squad'] | |
['text_b'][1]]['text'] # a list | |
text_b = [question+' '+answer_ele for answer_ele in answer] | |
text_c = [] | |
r.extract_keywords_from_text(text_a) | |
keywords_in_context = r.get_ranked_phrases()[:topk] | |
for each_keyword in keywords_in_context: | |
# then it is an incorrect answer | |
if sentence_bleu([answer_ele.lower().split() for answer_ele in answer], each_keyword.split(), weights=(0.33, 0.33, 0.33)) < threshold: | |
text_c.append(question+' '+each_keyword) | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
    def process_squad_v2(self):
        """Build binary examples from SQuAD v2.

        Answerable questions become label-1 examples using a declarative form
        of their gold answer; unanswerable ones become label-0 examples using
        an answer hallucinated by a QA model, likewise made declarative.
        """
        # first collect answerable items
        not_answerable_contexts = []
        not_answerable_questions = []
        not_answerable_answers = []
        answerable_contexts = []
        answerable_questions = []
        answerable_answers = []
        qa_generator = QAnswering(batch_size=32, device='cuda')
        qa2d_generator = QA2D(batch_size=32, device='cuda')
        for example in tqdm(self.datasets['squad_v2'], desc=f'Collecting (not)answerable examples'):
            # SQuAD v2 marks unanswerable questions with an empty answer list.
            if len(example['answers']['text']) == 0:
                not_answerable_contexts.append(example['context'])
                not_answerable_questions.append(example['question'])
            else:
                answerable_contexts.append(example['context'])
                answerable_questions.append(example['question'])
                answerable_answers.append(example['answers']['text'][0])
        # Hallucinate plausible-but-unsupported answers for the unanswerable set.
        not_answerable_answers = qa_generator.generate(
            not_answerable_questions, not_answerable_contexts)
        answerable_declarative_sents = qa2d_generator.generate(
            answerable_questions, answerable_answers)
        not_answerable_declarative_sents = qa2d_generator.generate(
            not_answerable_questions, not_answerable_answers)
        output = []
        for i, dec_sent in enumerate(answerable_declarative_sents):
            output.append({
                'text_a': answerable_contexts[i],
                'text_b': [dec_sent],
                'text_c': [],
                'label': 1
            })
        for i, dec_sent in enumerate(not_answerable_declarative_sents):
            output.append({
                'text_a': not_answerable_contexts[i],
                'text_b': [dec_sent],
                'text_c': [],
                'label': 0
            })
        return output
def process_race(self): | |
qa2d_generator = QA2D(batch_size=32, device='cuda') | |
option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3} | |
output = [] | |
correct_context = [] | |
correct_question = [] | |
correct_answer = [] | |
wrong_context = [] | |
wrong_question = [] | |
wrong_answer = [] | |
for example in tqdm(self.datasets['race'], desc=f'Constructing race'): | |
text_a = example[DATASET_CONFIG['race']['text_a']] | |
label = -1 | |
question = example[DATASET_CONFIG['race']['text_b'][0]] | |
if "_" in question: | |
answer_id = option_dict[example[DATASET_CONFIG['race']['label']]] | |
for i, options in enumerate(example[DATASET_CONFIG['race']['text_b'][1]]): | |
if i == answer_id: | |
output.append({ | |
'text_a': text_a, | |
'text_b': [' '.join(question.replace("_", " "+options+" ").split())], | |
'text_c': [], | |
'label': 1 | |
}) | |
else: | |
output.append({ | |
'text_a': text_a, | |
'text_b': [' '.join(question.replace("_", " "+options+" ").split())], | |
'text_c': [], | |
'label': 0 | |
}) | |
else: | |
answer_id = option_dict[example[DATASET_CONFIG['race']['label']]] | |
for i, options in enumerate(example[DATASET_CONFIG['race']['text_b'][1]]): | |
if i == answer_id: | |
output.append({ | |
'text_a': text_a, | |
'text_b': [question], | |
'text_c': [options], | |
'label': 1 | |
}) | |
else: | |
output.append({ | |
'text_a': text_a, | |
'text_b': [question], | |
'text_c': [options], | |
'label': 0 | |
}) | |
return output | |
def process_race_val(self): | |
qa2d_generator = QA2D(batch_size=32, device='cuda') | |
option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3} | |
output = [] | |
correct_context = [] | |
correct_question = [] | |
correct_answer = [] | |
wrong_context = [] | |
wrong_question = [] | |
wrong_answer = [] | |
for example in tqdm(self.datasets['race_val'], desc=f'Constructing race_val'): | |
text_a = example[DATASET_CONFIG['race_val']['text_a']] | |
label = -1 | |
question = example[DATASET_CONFIG['race_val']['text_b'][0]] | |
if "_" in question: | |
answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]] | |
for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]): | |
if i == answer_id: | |
output.append({ | |
'text_a': text_a, | |
'text_b': [' '.join(question.replace("_", " "+options+" ").split())], | |
'text_c': [], | |
'label': 1 | |
}) | |
else: | |
output.append({ | |
'text_a': text_a, | |
'text_b': [' '.join(question.replace("_", " "+options+" ").split())], | |
'text_c': [], | |
'label': 0 | |
}) | |
else: | |
answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]] | |
for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]): | |
if i == answer_id: | |
correct_context.append(text_a) | |
correct_question.append(question) | |
correct_answer.append(options) | |
else: | |
wrong_context.append(text_a) | |
wrong_question.append(question) | |
wrong_answer.append(options) | |
correct_declarative = qa2d_generator.generate( | |
correct_question, correct_answer) | |
wrong_declarative = qa2d_generator.generate( | |
wrong_question, wrong_answer) | |
assert len(correct_context) == len(correct_declarative) | |
assert len(wrong_context) == len(wrong_declarative) | |
for context, dec in zip(correct_context, correct_declarative): | |
output.append({ | |
'text_a': context, | |
'text_b': [dec], | |
'text_c': [], | |
'label': 1 | |
}) | |
for context, dec in zip(wrong_context, wrong_declarative): | |
output.append({ | |
'text_a': context, | |
'text_b': [dec], | |
'text_c': [], | |
'label': 0 | |
}) | |
return output | |
def process_race_test(self): | |
option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3} | |
output = [] | |
for example in tqdm(self.datasets['race_test'], desc=f'Constructing race_test'): | |
text_a = example[DATASET_CONFIG['race_test']['text_a']] | |
text_b = [] # pos | |
text_c = [] # neg | |
label = -1 | |
question = example[DATASET_CONFIG['race_test']['text_b'][0]] | |
if "_" in question: | |
answer_id = option_dict[example[DATASET_CONFIG['race_test']['label']]] | |
for i, options in enumerate(example[DATASET_CONFIG['race_test']['text_b'][1]]): | |
if i == answer_id: | |
text_b.append(' '.join(question.replace( | |
"_", " "+options+" ").split())) | |
else: | |
text_c.append(' '.join(question.replace( | |
"_", " "+options+" ").split())) | |
else: | |
answer_id = option_dict[example[DATASET_CONFIG['race_test']['label']]] | |
for i, options in enumerate(example[DATASET_CONFIG['race_test']['text_b'][1]]): | |
if i == answer_id: | |
text_b.append(question+" "+options+" ") | |
else: | |
text_c.append(question+" "+options+" ") | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_xsum(self): | |
''' | |
text_a: raw_text | |
text_b: raw_summary + ***extractive summ*** removed | |
text_c: cliff xsum + DistillBERT from raw_text_b + ***DistillBERT from extractive summ text_b*** | |
''' | |
output = [] | |
gold_summary = [example[DATASET_CONFIG['xsum']['text_b']] | |
for example in self.datasets['xsum']] | |
ext_summarizer = ExtractiveSummarizationGenerator() | |
extracted_summ = ext_summarizer.generate( | |
[example[DATASET_CONFIG['xsum']['text_a']] for example in self.datasets['xsum']]) | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25) | |
gold_summary_hallucinated = mlm_hallucinator.generate() | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25) | |
extracted_summ_hallucinated = mlm_hallucinator.generate() | |
assert len(self.datasets['xsum']) == len(gold_summary_hallucinated) and len( | |
self.datasets['xsum']) == len(extracted_summ_hallucinated) | |
for i, example in tqdm(enumerate(self.datasets['xsum']), desc="Constructing xsum", total=len(self.datasets['xsum'])): | |
text_a = example[DATASET_CONFIG['xsum']['text_a']] | |
text_b = [gold_summary[i], extracted_summ[i]] | |
text_c = [gold_summary_hallucinated[i], | |
extracted_summ_hallucinated[i]] | |
label = -1 | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_cnndm(self): | |
''' | |
text_a: raw_text | |
text_b: raw_summary + ***extractive summ*** removed | |
text_c: DistillBERT from raw_text_b + ***DistillBERT from extractive summ text_b*** | |
''' | |
# interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000 | |
output = [] | |
gold_summary = [example[DATASET_CONFIG['cnndm']['text_b']] | |
for example in self.datasets['cnndm']] | |
ext_summarizer = ExtractiveSummarizationGenerator() | |
extracted_summ = ext_summarizer.generate( | |
[example[DATASET_CONFIG['cnndm']['text_a']] for example in self.datasets['cnndm']]) | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25) | |
gold_summary_hallucinated = mlm_hallucinator.generate() | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25) | |
extracted_summ_hallucinated = mlm_hallucinator.generate() | |
assert len(self.datasets['cnndm']) == len(gold_summary_hallucinated) and len( | |
self.datasets['cnndm']) == len(extracted_summ_hallucinated) | |
for i, example in tqdm(enumerate(self.datasets['cnndm']), desc="Constructing cnndm", total=len(self.datasets['cnndm'])): | |
text_a = example[DATASET_CONFIG['cnndm']['text_a']] | |
text_b = [gold_summary[i], extracted_summ[i]] | |
text_c = [gold_summary_hallucinated[i], | |
extracted_summ_hallucinated[i]] | |
label = -1 | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_wikihow(self): | |
''' | |
text_a: raw_text | |
text_b: raw_summary + ***extractive summ*** removed | |
text_c: DistillBERT from raw_text_b + ***DistillBERT from extractive summ text_b*** | |
''' | |
# interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000 | |
output = [] | |
gold_summary = [example[DATASET_CONFIG['wikihow']['text_b']] | |
for example in self.datasets['wikihow']] | |
ext_summarizer = ExtractiveSummarizationGenerator() | |
extracted_summ = ext_summarizer.generate( | |
[example[DATASET_CONFIG['wikihow']['text_a']] for example in self.datasets['wikihow']]) | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25) | |
gold_summary_hallucinated = mlm_hallucinator.generate() | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25) | |
extracted_summ_hallucinated = mlm_hallucinator.generate() | |
assert len(self.datasets['wikihow']) == len(gold_summary_hallucinated) and len( | |
self.datasets['wikihow']) == len(extracted_summ_hallucinated) | |
for i, example in tqdm(enumerate(self.datasets['wikihow']), desc="Constructing wikihow", total=len(self.datasets['wikihow'])): | |
text_a = example[DATASET_CONFIG['wikihow']['text_a']] | |
text_b = [gold_summary[i], extracted_summ[i]] | |
text_c = [gold_summary_hallucinated[i], | |
extracted_summ_hallucinated[i]] | |
label = -1 | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_wiki103(self): | |
output = [] | |
paraphrases = [example[DATASET_CONFIG['wiki103']['text_b']] | |
for example in self.datasets['wiki103']] | |
mlm_hallucinator = MLMGeneratorWithPairedData( | |
corpra=paraphrases, device='cuda:3', batch_size=64, mask_percent=0.25) | |
paraphrase_hallucinated = mlm_hallucinator.generate() | |
assert len(self.datasets['wiki103']) == len(paraphrase_hallucinated) | |
for i, example in tqdm(enumerate(self.datasets['wiki103']), desc=f'Constructing wiki103'): | |
output.append({ | |
'text_a': example[DATASET_CONFIG['wiki103']['text_a']], | |
'text_b': [example[DATASET_CONFIG['wiki103']['text_b']]], | |
'text_c': [], | |
'label': 1 | |
}) | |
output.append({ | |
'text_a': example[DATASET_CONFIG['wiki103']['text_a']], | |
'text_b': [paraphrase_hallucinated[i]], | |
'text_c': [], | |
'label': 0 | |
}) | |
return output | |
def process_mnli(self): | |
output = [] | |
for example in tqdm(self.datasets['mnli'], desc=f'Constructing mnli'): | |
text_a = example[DATASET_CONFIG['mnli']['text_a']] | |
text_b = [example[DATASET_CONFIG['mnli']['text_b']]] | |
text_c = [] | |
label = example[DATASET_CONFIG['mnli']['label']] | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_nli_fever(self): | |
output = [] | |
for example in tqdm(self.datasets['nli_fever'], desc=f'Constructing nli_fever'): | |
text_a = example[DATASET_CONFIG['nli_fever']['text_a']] | |
text_b = [example[DATASET_CONFIG['nli_fever']['text_b']]] | |
text_c = [] | |
raw_label = example[DATASET_CONFIG['nli_fever']['label']] | |
if raw_label == 'SUPPORTS': # convert to nli style label | |
label = 0 | |
elif raw_label == 'REFUTES': | |
label = 2 | |
else: | |
label = 1 | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_doc_nli(self): | |
output = [] | |
for example in tqdm(self.datasets['doc_nli'], desc=f'Constructing doc_nli'): | |
text_a = example[DATASET_CONFIG['doc_nli']['text_a']] | |
text_b = [example[DATASET_CONFIG['doc_nli']['text_b']]] | |
text_c = [] | |
raw_label = example[DATASET_CONFIG['doc_nli']['label']] | |
if raw_label == 'entailment': # convert to paraphrase style label | |
label = 1 | |
else: | |
label = 0 | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_anli_r1(self): | |
output = [] | |
for example in tqdm(self.datasets['anli_r1'], desc=f'Constructing anli_r1'): | |
text_a = example[DATASET_CONFIG['anli_r1']['text_a']] | |
text_b = [example[DATASET_CONFIG['anli_r1']['text_b']]] | |
text_c = [] | |
label = example[DATASET_CONFIG['anli_r1']['label']] | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_anli_r2(self): | |
output = [] | |
for example in tqdm(self.datasets['anli_r2'], desc=f'Constructing anli_r2'): | |
text_a = example[DATASET_CONFIG['anli_r2']['text_a']] | |
text_b = [example[DATASET_CONFIG['anli_r2']['text_b']]] | |
text_c = [] | |
label = example[DATASET_CONFIG['anli_r2']['label']] | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_anli_r3(self): | |
output = [] | |
for example in tqdm(self.datasets['anli_r3'], desc=f'Constructing anli_r3'): | |
text_a = example[DATASET_CONFIG['anli_r3']['text_a']] | |
text_b = [example[DATASET_CONFIG['anli_r3']['text_b']]] | |
text_c = [] | |
label = example[DATASET_CONFIG['anli_r3']['label']] | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_snli(self): | |
output = [] | |
for example in tqdm(self.datasets['snli'], desc=f'Constructing snli'): | |
text_a = example[DATASET_CONFIG['snli']['text_a']] | |
text_b = [example[DATASET_CONFIG['snli']['text_b']]] | |
text_c = [] | |
label = example[DATASET_CONFIG['snli']['label']] | |
output.append({ | |
'text_a': text_a, | |
'text_b': text_b, | |
'text_c': text_c, | |
'label': label | |
}) | |
return output | |
def process_paws(self):
    """Build {text_a, text_b, text_c, label} records from the paws split."""
    cfg = DATASET_CONFIG['paws']
    records = []
    for ex in tqdm(self.datasets['paws'], desc='Constructing paws'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_vitaminc(self):
    """Build records from vitaminc, mapping fact-check labels to NLI-style ids.

    SUPPORTS -> 0 (entailment), REFUTES -> 2 (contradiction),
    anything else -> 1 (neutral).
    """
    cfg = DATASET_CONFIG['vitaminc']
    label_map = {'SUPPORTS': 0, 'REFUTES': 2}
    records = []
    for ex in tqdm(self.datasets['vitaminc'], desc='Constructing vitaminc'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': label_map.get(ex[cfg['label']], 1),
        })
    return records
def process_stsb(self):
    """Build records from stsb; the 0-5 similarity score is normalized to [0, 1]."""
    cfg = DATASET_CONFIG['stsb']
    records = []
    for ex in tqdm(self.datasets['stsb'], desc='Constructing stsb'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']] / 5.0,
        })
    return records
def process_sick(self):
    """Build records from sick; the 0-5 relatedness score is normalized to [0, 1]."""
    cfg = DATASET_CONFIG['sick']
    records = []
    for ex in tqdm(self.datasets['sick'], desc='Constructing sick'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']] / 5.0,
        })
    return records
def process_mrpc(self):
    """Build {text_a, text_b, text_c, label} records from the mrpc split."""
    cfg = DATASET_CONFIG['mrpc']
    records = []
    for ex in tqdm(self.datasets['mrpc'], desc='Constructing mrpc'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_mrpc_val(self):
    """Build {text_a, text_b, text_c, label} records from the mrpc validation split."""
    cfg = DATASET_CONFIG['mrpc_val']
    records = []
    for ex in tqdm(self.datasets['mrpc_val'], desc='Constructing mrpc_val'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_paws_val(self):
    """Build {text_a, text_b, text_c, label} records from the paws validation split."""
    cfg = DATASET_CONFIG['paws_val']
    records = []
    for ex in tqdm(self.datasets['paws_val'], desc='Constructing paws_val'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_paws_unlabeled(self):
    """Build {text_a, text_b, text_c, label} records from the unlabeled paws split."""
    cfg = DATASET_CONFIG['paws_unlabeled']
    records = []
    for ex in tqdm(self.datasets['paws_unlabeled'], desc='Constructing paws_unlabeled'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_qqp(self):
    """Build {text_a, text_b, text_c, label} records from the qqp split."""
    cfg = DATASET_CONFIG['qqp']
    records = []
    for ex in tqdm(self.datasets['qqp'], desc='Constructing qqp'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_qqp_val(self):
    """Build {text_a, text_b, text_c, label} records from the qqp validation split."""
    cfg = DATASET_CONFIG['qqp_val']
    records = []
    for ex in tqdm(self.datasets['qqp_val'], desc='Constructing qqp_val'):
        records.append({
            'text_a': ex[cfg['text_a']],
            'text_b': [ex[cfg['text_b']]],
            'text_c': [],
            'label': ex[cfg['label']],
        })
    return records
def process_msmarco(self):
    """Build passage-ranking records from MS MARCO.

    For every query that has at least one selected passage, each passage of
    the query becomes a record: selected passages get label 1, the rest
    label 0. Queries with no selected passage are skipped entirely.

    Fix: the original instantiated an unused QA2D model (loading it onto the
    GPU) and filled nine accumulator lists (questions, answers, contexts, ...)
    that were never read; that dead code is removed. The returned records are
    unchanged.
    """
    output = []
    for example in tqdm(self.datasets['msmarco'], desc='Collecting msmarco'):
        passages = example['passages']
        if sum(passages['is_selected']) == 0:
            continue  # no gold passage for this query
        for passage_text, is_selected in zip(passages['passage_text'],
                                             passages['is_selected']):
            output.append({
                'text_a': passage_text,
                'text_b': [example['query']],
                'text_c': [],
                'label': 1 if is_selected == 1 else 0,
            })
    return output
def process_paws_qqp(self):
    """Build records from the paws_qqp DataFrame.

    Sentences are stored as stringified bytes literals (b'...'); slicing
    [2:-1] strips the b' prefix and trailing quote.
    """
    frame = self.datasets['paws_qqp']
    records = []
    for _, row in frame.iterrows():
        records.append({
            'text_a': row['sentence1'][2:-1],
            'text_b': [row['sentence2'][2:-1]],
            'text_c': [],
            'label': int(row['label']),
        })
    return records
def process_wmt15(self):
    """WMT15 metric data: reference/candidate pairs with a human score label."""
    return [
        {
            'text_a': ex['reference'],
            'text_b': [ex['candidate']],
            'text_c': [],
            'label': ex['score'],
        }
        for ex in self.datasets['wmt15']
    ]
def process_wmt16(self):
    """WMT16 metric data: reference/candidate pairs with a human score label."""
    return [
        {
            'text_a': ex['reference'],
            'text_b': [ex['candidate']],
            'text_c': [],
            'label': ex['score'],
        }
        for ex in self.datasets['wmt16']
    ]
def process_wmt17(self):
    """WMT17 metric data: reference/candidate pairs with a human score label."""
    return [
        {
            'text_a': ex['reference'],
            'text_b': [ex['candidate']],
            'text_c': [],
            'label': ex['score'],
        }
        for ex in self.datasets['wmt17']
    ]
def process_wmt18(self):
    """WMT18 metric data: reference/candidate pairs with a human score label."""
    return [
        {
            'text_a': ex['reference'],
            'text_b': [ex['candidate']],
            'text_c': [],
            'label': ex['score'],
        }
        for ex in self.datasets['wmt18']
    ]
def process_wmt19(self):
    """WMT19 metric data: reference/candidate pairs with a human score label."""
    return [
        {
            'text_a': ex['reference'],
            'text_b': [ex['candidate']],
            'text_c': [],
            'label': ex['score'],
        }
        for ex in self.datasets['wmt19']
    ]
def process_boolq(self):
    """BoolQ: each passage/question yields two records — the true yes/no
    answer with label 1, and its negation with label 0."""
    records = []
    for ex in self.datasets['boolq']:
        true_text = "Yes." if ex['answer'] else "No."
        false_text = "No." if ex['answer'] else "Yes."
        for answer_text, label in ((true_text, 1), (false_text, 0)):
            records.append({
                'text_a': ex['passage'],
                'text_b': [ex['question']],
                'text_c': [answer_text],
                'label': label,
            })
    return records
def process_eraser_multi_rc(self):
    """ERASER MultiRC: passage vs. query+answer pairs; the '|' separators
    are stripped from the query text."""
    records = []
    for ex in self.datasets['eraser_multi_rc']:
        records.append({
            'text_a': ex['passage'],
            'text_b': [ex['query_and_answer'].replace("|", "")],
            'text_c': [],
            'label': int(ex['label']),
        })
    return records
def process_quail(self):
    """QuAIL: one record per answer option; label 1 marks the correct option."""
    records = []
    for ex in self.datasets['quail']:
        correct = ex['correct_answer_id']
        for idx, option in enumerate(ex['answers']):
            records.append({
                'text_a': ex['context'],
                'text_b': [ex['question']],
                'text_c': [option],
                'label': int(idx == correct),
            })
    return records
def process_sciq(self):
    """SciQ: three distractor records (label 0) followed by the correct
    answer record (label 1) for each question."""
    records = []
    for ex in self.datasets['sciq']:
        support = ex['support']
        for key in ('distractor1', 'distractor2', 'distractor3'):
            records.append({
                'text_a': support,
                'text_b': [ex['question']],
                'text_c': [ex[key]],
                'label': 0,
            })
        records.append({
            'text_a': support,
            'text_b': [ex['question']],
            'text_c': [ex['correct_answer']],
            'label': 1,
        })
    return records
def process_strategy_qa(self):
    """StrategyQA: facts are joined into the context; the true yes/no answer
    gets label 1 and its negation label 0."""
    records = []
    for ex in self.datasets['strategy_qa']:
        context = ' '.join(ex['facts'])
        true_text = "Yes." if ex['answer'] else "No."
        false_text = "No." if ex['answer'] else "Yes."
        for answer_text, label in ((true_text, 1), (false_text, 0)):
            records.append({
                'text_a': context,
                'text_b': [ex['question']],
                'text_c': [answer_text],
                'label': label,
            })
    return records
def process_gap(self):
    """GAP coreference: for each example, substitute candidate A and then B
    for the pronoun span; the label says whether that candidate is the
    true coreferent."""
    records = []
    for ex in self.datasets['gap']:
        text = ex['Text']
        start = ex['Pronoun-offset']
        end = start + len(ex['Pronoun'])
        for cand_key, coref_key in (('A', 'A-coref'), ('B', 'B-coref')):
            substituted = text[:start] + ex[cand_key] + text[end:]
            records.append({
                'text_a': text,
                'text_b': [substituted],
                'text_c': [],
                'label': 1 if ex[coref_key] else 0,
            })
    return records
def init_qa_t5(self):
    """Lazily load the t5-base QA model and tokenizer (once) onto cuda:1."""
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    if self.t5_qa is not None:
        return  # already initialised
    self.t5_tokenizer = T5Tokenizer.from_pretrained(
        "t5-base", model_max_length=800)
    self.t5_qa = T5ForConditionalGeneration.from_pretrained("t5-base")
    self.t5_qa.to('cuda:1')
    self.t5_qa.eval()
def mask_answer(self, context, answers):
    """Remove every answer span from *context*, longest answer first and
    case-insensitively.

    The lookaround guards (word chars, '-', and en dash U+2013) prevent
    deleting a substring of a longer token (e.g. 'cat' inside 'concatenate'
    or 'cat-mat').

    Fix: the original definition omitted ``self`` even though the method is
    invoked as ``self.mask_answer(...)`` (see generate_fake_answer), which
    would have bound the instance to ``context``.
    """
    for answer in sorted(answers, key=len, reverse=True):
        pattern = rf'(?<![\w\-\u2013]){re.escape(answer)}(?![\w\-\u2013])'
        context = re.sub(pattern, '', context, flags=re.IGNORECASE)
    return context
def generate_fake_answer(self, context, question, answers):
    """Produce a distractor answer by running T5 QA over the context with
    every gold answer masked out, so the model must hallucinate one."""
    self.init_qa_t5()
    masked_context = self.mask_answer(context, answers)
    encoded = self.t5_tokenizer(
        f'question: {question} context: {masked_context}',
        return_tensors="pt",
        truncation='only_first'
    )
    generated = self.t5_qa.generate(
        encoded.input_ids.to(self.t5_qa.device),
        max_new_tokens=40,
        remove_invalid_values=True
    )
    return self.t5_tokenizer.decode(generated[0], skip_special_tokens=True)
def negative_sample_qa(self, samples, negative_sample_no_ans_only=True):
    """Turn (context, question, answers) triples into labelled QA records.

    Gold answers yield a label-1 record. A T5-generated fake answer yields a
    label-0 record for unanswerable questions — and, when
    *negative_sample_no_ans_only* is False, for answerable ones as well.
    """
    records = []
    for context, question, answers in samples:
        if answers:
            records.append({
                'text_a': context,
                'text_b': [question],
                'text_c': answers,
                'label': 1
            })
        if not answers or not negative_sample_no_ans_only:
            fake = self.generate_fake_answer(context, question, answers)
            records.append({
                'text_a': context,
                'text_b': [question],
                'text_c': [fake],
                'label': 0
            })
    return records
def process_squad_v2_new(self):
    """SQuAD v2: negative-sample only the unanswerable questions (default)."""
    def triples():
        for ex in tqdm(self.datasets['squad_v2_new'], desc='squad_v2_new'):
            yield ex['context'], ex['question'], ex['answers']['text']
    return self.negative_sample_qa(triples())
def process_adversarial_qa(self):
    """AdversarialQA: negative-sample every question."""
    def triples():
        for ex in tqdm(self.datasets['adversarial_qa'], desc='adversarial_qa'):
            yield ex['context'], ex['question'], ex['answers']['text']
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def process_drop(self):
    """DROP: negative-sample every question; answers come from span annotations."""
    def triples():
        for ex in tqdm(self.datasets['drop'], desc='drop'):
            yield ex['passage'], ex['question'], ex['answers_spans']['spans']
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def process_duorc_self(self):
    """DuoRC (SelfRC): movie plot as context; negative-sample every question."""
    def triples():
        for ex in tqdm(self.datasets['duorc_self'], desc='duorc_self'):
            yield ex['plot'], ex['question'], ex['answers']
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def process_duorc_paraphrase(self):
    """DuoRC (ParaphraseRC): movie plot as context; negative-sample every question."""
    def triples():
        for ex in tqdm(self.datasets['duorc_paraphrase'], desc='duorc_paraphrase'):
            yield ex['plot'], ex['question'], ex['answers']
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def process_quoref(self):
    """Quoref: negative-sample every question."""
    def triples():
        for ex in tqdm(self.datasets['quoref'], desc='quoref'):
            yield ex['context'], ex['question'], ex['answers']['text']
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def prepare_hotpot_qa_samples(self, dataset):
    """Yield (context, question, [answer]) triples for HotpotQA examples.

    The context contains every supporting paragraph plus, when fewer than
    four paragraphs were collected, one randomly chosen distractor
    paragraph; the paragraphs are shuffled so position carries no signal.

    Fix: the original definition omitted ``self`` even though the method is
    invoked as ``self.prepare_hotpot_qa_samples(...)`` (see the hotpot_qa
    processors), which would have bound the instance to the dataset
    parameter. Also fixes the 'dateset'/'setences' local-name typos.
    """
    for sample in dataset:
        question = sample['question']
        answer = sample['answer']
        supporting_titles = set(sample['supporting_facts']['title'])
        context_paragraphs = []
        irrelevant_paragraphs = []
        for title, sentences in zip(sample['context']['title'],
                                    sample['context']['sentences']):
            paragraph = ''.join(sentences)
            if title in supporting_titles:
                context_paragraphs.append(paragraph)
            else:
                irrelevant_paragraphs.append(paragraph)
        # Add one irrelevant paragraph as a distractor when there is room.
        if irrelevant_paragraphs and len(context_paragraphs) < 4:
            context_paragraphs.append(random.choice(irrelevant_paragraphs))
        random.shuffle(context_paragraphs)
        yield '\n'.join(context_paragraphs), question, [answer]
def process_hotpot_qa_distractor(self):
    """HotpotQA (distractor setting): negative-sample every question."""
    dataset = tqdm(self.datasets['hotpot_qa_distractor'],
                   desc='hotpot_qa_distractor')
    triples = self.prepare_hotpot_qa_samples(dataset)
    return self.negative_sample_qa(triples, negative_sample_no_ans_only=False)
def process_hotpot_qa_fullwiki(self):
    """HotpotQA (fullwiki setting): negative-sample every question."""
    dataset = tqdm(self.datasets['hotpot_qa_fullwiki'],
                   desc='hotpot_qa_fullwiki')
    triples = self.prepare_hotpot_qa_samples(dataset)
    return self.negative_sample_qa(triples, negative_sample_no_ans_only=False)
def process_newsqa(self):
    """NewsQA: keep 'train' stories, skip questions whose badness score
    exceeds 0.2, and take the consensus span as the answer when present."""
    def triples():
        for story in tqdm(self.datasets['newsqa']['data'], desc='newsqa'):
            if story['type'] != 'train':
                continue
            context = story['text']
            for question in story['questions']:
                if question.get('isQuestionBad', 0.) > 0.2:
                    continue
                consensus = question['consensus']
                spans = []
                if 's' in consensus:
                    spans.append(context[consensus['s']:consensus['e']].strip())
                yield context, question['q'], spans
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def process_ropes(self):
    """ROPES: context is situation + background; negative-sample every question."""
    def triples():
        for ex in tqdm(self.datasets['ropes'], desc='ropes'):
            context = ex['situation'] + ' ' + ex['background']
            yield context, ex['question'], ex['answers']['text']
    return self.negative_sample_qa(triples(), negative_sample_no_ans_only=False)
def generate(self):
    """Run every dataset's processor and dump one JSON record per line to
    ./data/training/<dataset>.json.

    Fixes: dispatch via getattr instead of eval() on a constructed string
    (eval is both unsafe and needlessly slow), and write each dataset's file
    in a single open-in-'w' pass instead of truncating everything upfront and
    then re-opening the file in append mode for every single record. The
    resulting files are identical.
    """
    for dataset_name in self.datasets:
        processor = getattr(self, f'process_{dataset_name}')
        records = processor()
        with open(f'./data/training/{dataset_name}.json', 'w',
                  encoding='utf8') as outfile:
            for record in records:
                json.dump({
                    'task': DATASET_CONFIG[dataset_name]['task'],
                    # string
                    'text_a': record['text_a'],
                    # list of positive examples
                    'text_b': record['text_b'],
                    # list of negative examples
                    'text_c': record['text_c'],
                    # original label; -1 means only positive/negative pairs
                    'orig_label': record['label'],
                }, outfile, ensure_ascii=False)
                outfile.write('\n')
if __name__ == "__main__":
    # Fixed seed so random choices (distractor sampling, shuffling) are reproducible.
    random.seed(42)
    generator = DataGenerator(list(DATASET_CONFIG.keys()))
    generator.generate()