Chenxi Whitehouse
commited on
Commit
•
eaaaf3d
1
Parent(s):
4400462
add src files
Browse files- README.md +13 -6
- src/prediction/veracity_prediction.py +121 -0
- src/reranking/bm25_sentences.py +1 -1
- src/reranking/rerank_questions.py +106 -0
- src/retrieval/scraper_for_knowledge_store.py +1 -1
README.md
CHANGED
@@ -47,7 +47,7 @@ The training and dev dataset can be found under [data](https://huggingface.co/ch
|
|
47 |
|
48 |
## Reproduce the baseline
|
49 |
|
50 |
-
Below are the steps to reproduce the baseline results. The main difference from the reported results in the paper is that, instead of requiring direct access to the paid Google Search API, we provide such search results for up to 1000 URLs per claim using different queries, and the scraped text as a knowledge store for retrieval for each claim. This is aimed at reducing the overhead cost of participating in the Shared Task.
|
51 |
|
52 |
|
53 |
### 0. Set up environment
|
@@ -93,28 +93,35 @@ python -m src.reranking.bm25_sentences
|
|
93 |
```
|
94 |
|
95 |
### 3. Generate questions-answer pair for the top sentences
|
96 |
-
We use [BLOOM](https://huggingface.co/bigscience/bloom-7b1) to generate QA paris for each of the top 100 sentence, providing 10 closest claim-QA-pairs from the training set as in-context examples. See [question_generation_top_sentences.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/question_generation_top_sentences.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/
|
97 |
```bash
|
98 |
python -m src.reranking.question_generation_top_sentences
|
99 |
```
|
100 |
|
101 |
### 4. Rerank the QA pairs
|
102 |
-
Using a pre-trained BERT model [bert_dual_encoder.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_dual_encoder.ckpt), we rerank the QA paris and keep top 3 QA paris as evidence. We provide the output file for this step on the dev set [here]().
|
103 |
```bash
|
|
|
104 |
```
|
105 |
|
106 |
|
107 |
### 5. Veracity prediction
|
108 |
-
Finally, given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model [bert_veracity.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_veracity.ckpt) to predict the veracity label.
|
109 |
```bash
|
|
|
110 |
```
|
111 |
The results will be presented as follows:
|
112 |
-
|
|
|
113 |
```
|
114 |
|
115 |
-
We recommend using 0.25 as cut-off score for evaluating the relevance of the evidence. The result for dev and the test set below.
|
116 |
|
|
|
117 |
|
|
|
|
|
|
|
|
|
118 |
|
119 |
## Citation
|
120 |
If you find AVeriTeC useful for your research and applications, please cite us using this BibTeX:
|
|
|
47 |
|
48 |
## Reproduce the baseline
|
49 |
|
50 |
+
Below are the steps to reproduce the baseline results. The main difference from the reported results in the paper is that, instead of requiring direct access to the paid Google Search API, we provide such search results for up to 1000 URLs per claim using different queries, and the scraped text as a knowledge store for retrieval for each claim. This is aimed at reducing the overhead cost of participating in the Shared Task. Another difference is that we also added text scraped from pdf URLs to the knowledge store.
|
51 |
|
52 |
|
53 |
### 0. Set up environment
|
|
|
93 |
```
|
94 |
|
95 |
### 3. Generate questions-answer pair for the top sentences
|
96 |
+
We use [BLOOM](https://huggingface.co/bigscience/bloom-7b1) to generate QA paris for each of the top 100 sentence, providing 10 closest claim-QA-pairs from the training set as in-context examples. See [question_generation_top_sentences.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/question_generation_top_sentences.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_top_k_qa.json).
|
97 |
```bash
|
98 |
python -m src.reranking.question_generation_top_sentences
|
99 |
```
|
100 |
|
101 |
### 4. Rerank the QA pairs
|
102 |
+
Using a pre-trained BERT model [bert_dual_encoder.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_dual_encoder.ckpt), we rerank the QA paris and keep top 3 QA paris as evidence. See [rerank_questions.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/rerank_questions.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_top_3_rerank_qa.json).
|
103 |
```bash
|
104 |
+
python -m reranking.rerank_questions
|
105 |
```
|
106 |
|
107 |
|
108 |
### 5. Veracity prediction
|
109 |
+
Finally, given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model [bert_veracity.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_veracity.ckpt) to predict the veracity label. See [veracity_prediction.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/prediction/veracity_prediction.py) for more argument options. We provide the prediction file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_vericity_prediction.json).
|
110 |
```bash
|
111 |
+
python -m prediction.veracity_prediction
|
112 |
```
|
113 |
The results will be presented as follows:
|
114 |
+
|
115 |
+
```
|
116 |
```
|
117 |
|
|
|
118 |
|
119 |
+
We recommend using 0.25 as cut-off score for evaluating the relevance of the evidence. The result for dev and the test set below.
|
120 |
|
121 |
+
| Model | Split | Q only | Q + A | Veracity @ 0.2 | @ 0.25 | @ 0.3 |
|
122 |
+
|-------------------|-------|--------|-------|----------------|--------|-------|
|
123 |
+
| AVeriTeC-BLOOM-7b | dev | | | | | |
|
124 |
+
| AVeriTeC-BLOOM-7b | test | | | | | |
|
125 |
|
126 |
## Citation
|
127 |
If you find AVeriTeC useful for your research and applications, please cite us using this BibTeX:
|
src/prediction/veracity_prediction.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import tqdm
|
4 |
+
import torch
|
5 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
6 |
+
from data_loaders.SequenceClassificationDataLoader import (
|
7 |
+
SequenceClassificationDataLoader,
|
8 |
+
)
|
9 |
+
from models.SequenceClassificationModule import SequenceClassificationModule
|
10 |
+
|
11 |
+
|
12 |
+
LABEL = [
|
13 |
+
"Supported",
|
14 |
+
"Refuted",
|
15 |
+
"Not Enough Evidence",
|
16 |
+
"Conflicting Evidence/Cherrypicking",
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
parser = argparse.ArgumentParser(
|
22 |
+
description="Given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model to predict the veracity label."
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"-i",
|
26 |
+
"--claim_with_evidence_file",
|
27 |
+
default="data/dev_top3_questions.json",
|
28 |
+
help="Json file with claim and top question-answer pairs as evidence.",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"-o",
|
32 |
+
"--output_file",
|
33 |
+
default="data_store/dev_veracity.json",
|
34 |
+
help="Json file with the veracity predictions.",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"-ckpt",
|
38 |
+
"--best_checkpoint",
|
39 |
+
type=str,
|
40 |
+
default="pretrained_models/bert_veracity.ckpt",
|
41 |
+
)
|
42 |
+
args = parser.parse_args()
|
43 |
+
|
44 |
+
with open(args.claim_with_evidence_file) as f:
|
45 |
+
examples = json.load(f)
|
46 |
+
|
47 |
+
bert_model_name = "bert-base-uncased"
|
48 |
+
|
49 |
+
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
50 |
+
bert_model = BertForSequenceClassification.from_pretrained(
|
51 |
+
bert_model_name, num_labels=4, problem_type="single_label_classification"
|
52 |
+
)
|
53 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
54 |
+
trained_model = SequenceClassificationModule.load_from_checkpoint(
|
55 |
+
args.best_checkpoint, tokenizer=tokenizer, model=bert_model
|
56 |
+
).to(device)
|
57 |
+
|
58 |
+
dataLoader = SequenceClassificationDataLoader(
|
59 |
+
tokenizer=tokenizer,
|
60 |
+
data_file="this_is_discontinued",
|
61 |
+
batch_size=32,
|
62 |
+
add_extra_nee=False,
|
63 |
+
)
|
64 |
+
|
65 |
+
predictions = []
|
66 |
+
|
67 |
+
for example in tqdm.tqdm(examples):
|
68 |
+
example_strings = []
|
69 |
+
for evidence in example["evidence"]:
|
70 |
+
example_strings.append(
|
71 |
+
dataLoader.quadruple_to_string(
|
72 |
+
example["claim"], evidence["question"], evidence["answer"], ""
|
73 |
+
)
|
74 |
+
)
|
75 |
+
|
76 |
+
if (
|
77 |
+
len(example_strings) == 0
|
78 |
+
): # If we found no evidence e.g. because google returned 0 pages, just output NEI.
|
79 |
+
example["label"] = "Not Enough Evidence"
|
80 |
+
continue
|
81 |
+
|
82 |
+
tokenized_strings, attention_mask = dataLoader.tokenize_strings(example_strings)
|
83 |
+
example_support = torch.argmax(
|
84 |
+
trained_model(tokenized_strings, attention_mask=attention_mask).logits,
|
85 |
+
axis=1,
|
86 |
+
)
|
87 |
+
|
88 |
+
has_unanswerable = False
|
89 |
+
has_true = False
|
90 |
+
has_false = False
|
91 |
+
|
92 |
+
for v in example_support:
|
93 |
+
if v == 0:
|
94 |
+
has_true = True
|
95 |
+
if v == 1:
|
96 |
+
has_false = True
|
97 |
+
if v in (
|
98 |
+
2,
|
99 |
+
3,
|
100 |
+
): # TODO another hack -- we cant have different labels for train and test so we do this
|
101 |
+
has_unanswerable = True
|
102 |
+
|
103 |
+
if has_unanswerable:
|
104 |
+
answer = 2
|
105 |
+
elif has_true and not has_false:
|
106 |
+
answer = 0
|
107 |
+
elif not has_true and has_false:
|
108 |
+
answer = 1
|
109 |
+
else:
|
110 |
+
answer = 3
|
111 |
+
|
112 |
+
json_data = {
|
113 |
+
"claim_id": example["claim_id"],
|
114 |
+
"claim": example["claim"],
|
115 |
+
"evidence": example["evidence"],
|
116 |
+
"label": LABEL[answer],
|
117 |
+
}
|
118 |
+
predictions.append(json_data)
|
119 |
+
|
120 |
+
with open(args.output_file, "w", encoding="utf-8") as output_file:
|
121 |
+
json.dump(predictions, output_file, ensure_ascii=False, indent=4)
|
src/reranking/bm25_sentences.py
CHANGED
@@ -30,7 +30,7 @@ def retrieve_top_k_sentences(query, document, urls, top_k):
|
|
30 |
if __name__ == "__main__":
|
31 |
|
32 |
parser = argparse.ArgumentParser(
|
33 |
-
description="Get top 100 sentences
|
34 |
)
|
35 |
parser.add_argument(
|
36 |
"-k",
|
|
|
30 |
if __name__ == "__main__":
|
31 |
|
32 |
parser = argparse.ArgumentParser(
|
33 |
+
description="Get top 100 sentences with BM25 in the knowledge store."
|
34 |
)
|
35 |
parser.add_argument(
|
36 |
"-k",
|
src/reranking/rerank_questions.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import tqdm
|
5 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
6 |
+
from models.DualEncoderModule import DualEncoderModule
|
7 |
+
|
8 |
+
|
9 |
+
def triple_to_string(x):
|
10 |
+
return " </s> ".join([item.strip() for item in x])
|
11 |
+
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
parser = argparse.ArgumentParser(
|
15 |
+
description="Rerank the QA paris and keep top 3 QA paris as evidence using a pre-trained BERT model."
|
16 |
+
)
|
17 |
+
parser.add_argument(
|
18 |
+
"-i",
|
19 |
+
"--top_k_qa_file",
|
20 |
+
default="data/dev_top_k_qa.json",
|
21 |
+
help="Json file with claim and top k generated question-answer pairs.",
|
22 |
+
)
|
23 |
+
parser.add_argument(
|
24 |
+
"-o",
|
25 |
+
"--output_file",
|
26 |
+
default="data/dev_top_3_rerank_qa.json",
|
27 |
+
help="Json file with the top3 reranked questions.",
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"-ckpt",
|
31 |
+
"--best_checkpoint",
|
32 |
+
type=str,
|
33 |
+
default="pretrained_models/bert_dual_encoder.ckpt",
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--top_n",
|
37 |
+
type=int,
|
38 |
+
default=3,
|
39 |
+
help="top_n question answer pairs as evidence to keep.",
|
40 |
+
)
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
with open(args.top_k_qa_file) as f:
|
44 |
+
examples = json.load(f)
|
45 |
+
|
46 |
+
bert_model_name = "bert-base-uncased"
|
47 |
+
|
48 |
+
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
|
49 |
+
bert_model = BertForSequenceClassification.from_pretrained(
|
50 |
+
bert_model_name, num_labels=2, problem_type="single_label_classification"
|
51 |
+
)
|
52 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
53 |
+
trained_model = DualEncoderModule.load_from_checkpoint(
|
54 |
+
args.best_checkpoint, tokenizer=tokenizer, model=bert_model
|
55 |
+
).to(device)
|
56 |
+
|
57 |
+
with open(args.output_file, "w", encoding="utf-8") as output_file:
|
58 |
+
for example in tqdm.tqdm(examples):
|
59 |
+
strs_to_score = []
|
60 |
+
values = []
|
61 |
+
|
62 |
+
bm25_qau = example["bm25_qau"] if "bm25_qau" in example else []
|
63 |
+
claim = example["claim"]
|
64 |
+
|
65 |
+
for question, answer, url in bm25_qau:
|
66 |
+
str_to_score = triple_to_string([claim, question, answer])
|
67 |
+
|
68 |
+
strs_to_score.append(str_to_score)
|
69 |
+
values.append([question, answer, url])
|
70 |
+
|
71 |
+
if len(bm25_qau) > 0:
|
72 |
+
encoded_dict = tokenizer(
|
73 |
+
strs_to_score,
|
74 |
+
max_length=512,
|
75 |
+
padding="longest",
|
76 |
+
truncation=True,
|
77 |
+
return_tensors="pt",
|
78 |
+
).to(device)
|
79 |
+
|
80 |
+
input_ids = encoded_dict["input_ids"]
|
81 |
+
attention_masks = encoded_dict["attention_mask"]
|
82 |
+
|
83 |
+
scores = torch.softmax(
|
84 |
+
trained_model(input_ids, attention_mask=attention_masks).logits,
|
85 |
+
axis=-1,
|
86 |
+
)[:, 1]
|
87 |
+
|
88 |
+
top_n = torch.argsort(scores, descending=True)[: args.top_n]
|
89 |
+
evidence = [
|
90 |
+
{
|
91 |
+
"question": values[i][0],
|
92 |
+
"answer": values[i][1],
|
93 |
+
"url": values[i][2],
|
94 |
+
}
|
95 |
+
for i in top_n
|
96 |
+
]
|
97 |
+
else:
|
98 |
+
evidence = []
|
99 |
+
|
100 |
+
json_data = {
|
101 |
+
"claim_id": example["claim_id"],
|
102 |
+
"claim": claim,
|
103 |
+
"evidence": evidence,
|
104 |
+
}
|
105 |
+
output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
|
106 |
+
output_file.flush()
|
src/retrieval/scraper_for_knowledge_store.py
CHANGED
@@ -46,7 +46,7 @@ def scrape_text_from_url(url, temp_name):
|
|
46 |
|
47 |
if __name__ == "__main__":
|
48 |
|
49 |
-
parser = argparse.ArgumentParser(description="Scraping text from
|
50 |
parser.add_argument(
|
51 |
"-i",
|
52 |
"--tsv_input_file",
|
|
|
46 |
|
47 |
if __name__ == "__main__":
|
48 |
|
49 |
+
parser = argparse.ArgumentParser(description="Scraping text from URLs.")
|
50 |
parser.add_argument(
|
51 |
"-i",
|
52 |
"--tsv_input_file",
|