Commit 97e4014 (parent: 559114d)
dtruong46me committed: Upload 29 files
Files changed:
- .DS_Store +0 -0
- LICENSE +21 -0
- README.md +109 -12
- app.py +73 -0
- assets/distribution.png +0 -0
- assets/hist_dialogue+summary.png +0 -0
- assets/hist_dialogue.png +0 -0
- assets/hist_summary.png +0 -0
- assets/image2.png +0 -0
- gen_summary.py +66 -0
- requirements.txt +14 -0
- results/.gitignore +0 -0
- results/rouge_score.csv +9 -0
- run_evaluation.py +77 -0
- run_training.py +39 -0
- setup.sh +8 -0
- src/.DS_Store +0 -0
- src/data/create_dataset.py +115 -0
- src/data/ingest_data.py +113 -0
- src/data/merge_dataset.py +41 -0
- src/data/preprocessing.py +113 -0
- src/evaluate/evaluation.py +81 -0
- src/evaluate/rouge_metric.py +52 -0
- src/model/model.py +80 -0
- src/pipelines/deploy_pipeline.py +0 -0
- src/pipelines/training_pipeline.py +168 -0
- src/test/test_rouge.py +0 -0
- src/utils.py +131 -0
- test_streaming.py +76 -0
.DS_Store
ADDED
Binary file (6.15 kB)
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Dinh Truong Phan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,109 @@
# Problem Description

This project aims to develop a system capable of automatically **summarizing short dialogue text**. It addresses the challenge of extracting concise yet informative summaries from conversational exchanges, enabling users to **quickly grasp the key information of a dialogue**.

Summarizing these conversations can be valuable for various applications, such as:
- Streamlining information retrieval in customer service interactions
- Condensing meeting discussions for efficient review
- Providing concise overviews of chat conversations on social media platforms

This project tackles the task of automatically generating concise summaries, saving users time and effort while improving comprehension.

![](assets/image2.png)

<p align="center"><i>Source: Google Research</i></p>

**Input:** Dialogue text

Example:
```
Matt: Do you want to go for date?
Agnes: Wow! You caught me out with this question Matt.
...
Agnes: See you on saturday.
Matt: Yes, looking forward to it.
Agnes: Me too.
```

**Output:** Summarized dialogue

Example:
```
Matt invites Agnes for a date to get to know each other better. They'll go to the Georgian restaurant in Kazimierz on Saturday at 6 pm, and he'll pick her up on the way to the place.
```

# Dataset

We use the `DialogSum` dataset, available on 🤗 **Hugging Face** (https://huggingface.co/datasets/knkarthick/dialogsum) and described in the **paper** (https://arxiv.org/pdf/2105.06762.pdf). It comprises real-life dialogue scenarios paired with corresponding manually crafted summaries and dialogue topics.

`DialogSum` is a large-scale dialogue summarization dataset consisting of **13,460** dialogues (plus 100 holdout dialogues for topic generation) with corresponding manually labeled summaries and topics.
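For reference, here is a minimal sketch of loading the dataset with the 🤗 `datasets` library; the field names follow the sample table below:

```
from datasets import load_dataset

# Load DialogSum from the Hugging Face Hub
dataset = load_dataset("knkarthick/dialogsum")

sample = dataset["train"][0]
print(sample["dialogue"])  # raw multi-turn dialogue text
print(sample["summary"])   # human-written reference summary
print(sample["topic"])     # short topic label
```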

Here's a sample of the `DialogSum` dataset structure:

|id|dialogue|summary|topic|
|-|-|-|-|
|train_3|#Person1#: Why didn't you tell me you had a girlfriend? #Person2#: Sorry, I thought you knew. ... #Person1#: Oh, you men! You are all the same.|#Person1#'s angry because #Person2# didn't tell #Person1# that #Person2# had a girlfriend and would marry her.|have a girl friend|
|train_16|#Person1#: Tell me something about your Valentine's Day. ...#Person2#: Yeah, that is what the holiday is for, isn't it?|#Person2# tells #Person1# their Valentine's Day. #Person1# feels it's romantic.|Valentine's Day|
|...|...|...|...|

**Distribution of the dataset**

|Dialogue|Summary|Dialogue + Summary|
|:-:|:-:|:-:|
|![](assets/hist_dialogue.png)|![](assets/hist_summary.png)|![](assets/hist_dialogue+summary.png)|

# Method

### Pre-trained Language Models

This project explores two pre-trained language models well suited to dialogue summarization:

- **FLAN-T5:** excels at modeling complex relationships within text, making it effective at capturing the nuances of conversations.
- **BART:** has strong text-generation capabilities, making it adept at producing informative and well-structured summaries.

### Fine-tuning Techniques

To tailor these models specifically for dialogue summarization, we investigate several fine-tuning approaches:

- Instruction Fine-tuning
- Parameter-Efficient Fine-Tuning (PEFT), see the sketch below
  + Low-Rank Adaptation **(LoRA)**
  + Quantized Low-Rank Adaptation **(QLoRA)**
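
As a rough sketch, LoRA can be attached to a seq2seq checkpoint as shown below; the rank, alpha, dropout, and target modules mirror the defaults used in `src/utils.py` and `src/pipelines/training_pipeline.py`, and the checkpoint name is only an example:

```
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# Low-rank adapters are injected into the attention projections ("q", "v");
# only the adapter weights are trained while the base model stays frozen.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # typically well under 1% of all parameters
```

QLoRA follows the same pattern but first loads the base model with a 4-bit `BitsAndBytesConfig`, as the training pipeline does when `--quantize` is passed.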

# Installation

```
!git clone "https://github.com/dtruong46me/dialogue-text-summarization.git"
```
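After cloning, dependencies can be installed with `pip install -r requirements.txt` (or `sh setup.sh` on Kaggle). A fine-tuning run can then be launched with, for example, `python run_training.py --checkpoint google/flan-t5-base --datapath knkarthick/dialogsum --huggingface_hub_token <token> --wandb_token <token>`, optionally adding `--lora` or `--lora --quantize` for the PEFT variants; all flags are defined in `src/utils.py`, and the token values here are placeholders.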

# Contributions

**Supervisor:** Prof. Le Thanh Huong

**Student Group:**

|No.|Name|Student ID|Email|
|:-:|-|:-:|-|
|1|Phan Dinh Truong (Leader)|20214937|[email protected]|
|2|Nguyen Tung Luong|20214913|[email protected]|
|3|Vu Tuan Minh|20210597|[email protected]|
|4|Hoang Tu Quyen|20214929|[email protected]|

# [Bonus] How to run Streamlit on Kaggle

```
!pip install -q streamlit
```

```
!wget -q -O - ipv4.icanhazip.com
```

```
!npm install -g localtunnel -q
```

```
!streamlit run "/kaggle/working/dialogue-text-summarization/streamlit_app.py" & npx localtunnel --port 8501
```
app.py
ADDED
@@ -0,0 +1,73 @@
import streamlit as st

from transformers import GenerationConfig, BartModel, BartTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import time

import sys, os

path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)

from gen_summary import generate_summary


st.title("Dialogue Text Summarization")
st.caption("Natural Language Processing Project 20232")

st.write("---")

with st.sidebar:
    checkpoint = st.selectbox("Model", options=[
        "Choose model",
        "dtruong46me/train-bart-base",
        "dtruong46me/flant5-small",
        "dtruong46me/flant5-base",
        "dtruong46me/flan-t5-s",
        "ntluongg/bart-base-luong"
    ])
    st.button("Model detail", use_container_width=True)
    st.write("-----")
    st.write("**Generate Options:**")
    min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
    max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
    temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
    top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
    top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)


height = 200

input_text = st.text_area("Dialogue", height=height)

# Generation length bounds and sampling parameters come from the sidebar controls
generation_config = GenerationConfig(
    min_new_tokens=min_new_tokens,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if checkpoint == "Choose model":
    tokenizer = None
    model = None
else:
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)


if st.button("Submit"):
    st.write("---")
    st.write("## Summary")

    if checkpoint == "Choose model":
        st.error("Please select a model!")

    elif input_text == "":
        st.error("Please enter a dialogue!")

    else:
        st.write(generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer))
assets/distribution.png
ADDED
assets/hist_dialogue+summary.png
ADDED
assets/hist_dialogue.png
ADDED
assets/hist_summary.png
ADDED
assets/image2.png
ADDED
gen_summary.py
ADDED
@@ -0,0 +1,66 @@
import torch
from transformers import AutoTokenizer, GenerationConfig, TextStreamer, AutoModelForSeq2SeqLM

import logging

import warnings
warnings.filterwarnings("ignore")

# = = = = = = = = = = = Logging Setup = = = = = = = = = = = = =
logger = logging.getLogger(__name__)
logging.basicConfig(
    format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt = "%m/%d/%Y %H:%M:%S",
    level = logging.INFO,
)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = =

def generate_summary(model, input_text, generation_config, tokenizer, st_container=None) -> str:

    try:
        prefix = "Summarize the following conversation: \n###\n"
        suffix = "\n### Summary:"

        prompt = prefix + input_text + "The generated summary should be around " + str(int(0.15 * len(input_text))) + " words." + suffix
        # Keep the inputs on the same device as the model
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        output_ids = model.generate(input_ids, do_sample=True, generation_config=generation_config)

        if "bart" in model.name_or_path and model.name_or_path != "dtruong46me/bart-base-qds":
            output_ids[0][1] = 2

        # streamer = TextStreamer(tokenizer, skip_special_tokens=True)
        # model.generate(input_ids, streamer=streamer, do_sample=True, decoder_start_token_id=2, generation_config=generation_config)
        # logger.info("\nComplete generate summary!")
        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return output_text

    except Exception as e:
        print(f"Error while generating: {e}")
        raise e

if __name__=="__main__":
    input = "#Person1#: Ms. Dawson, I need you to take a dictation for me. #Person2#: Yes, sir... #Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready? #Person2#: Yes, sir. Go ahead. #Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited. #Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications? #Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications. #Person2#: But sir, many employees use Instant Messaging to communicate with their clients. #Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with the memo. Where were we? #Person2#: This applies to internal and external communications. #Person1#: Yes. Any employee who persists in using Instant Messaging will first receive a warning and be placed on probation. At second offense, the employee will face termination. Any questions regarding this new policy may be directed to department heads. #Person2#: Is that all? #Person1#: Yes. Please get this memo typed up and distributed to all employees before 4 pm."
    target1 = "Ms. Dawson helps #Person1# to write a memo to inform every employee that they have to change the communication method and should not use Instant Messaging anymore."
    target2 = "In order to prevent employees from wasting time on Instant Message programs, #Person1# decides to terminate the use of those programs and asks Ms. Dawson to send out a memo to all employees by the afternoon."
    target3 = "Ms. Dawson takes a dictation for #Person1# about prohibiting the use of Instant Message programs in the office. They argue about its reasonability but #Person1# still insists."

    generation_config = GenerationConfig(
        min_new_tokens=10,
        max_new_tokens=256,
        temperature=0.9,
        top_p=1.0,
        top_k=50
    )

    checkpoint = "dtruong46me/bart-base-qds2"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

    print("Model summary:\n", generate_summary(model, input, generation_config, tokenizer))
    print("\n==============\n")

    print("Human baseline:\n", target1, end="\n\n")
    print("Human baseline:\n", target2, end="\n\n")
    print("Human baseline:\n", target3, end="\n\n")
requirements.txt
ADDED
@@ -0,0 +1,14 @@
datasets
huggingface_hub
nltk
numpy
pandas
peft
replicate
streamlit
torch
transformers==4.36.1
wandb
evaluate
rouge_score
bert_score
results/.gitignore
ADDED
File without changes
results/rouge_score.csv
ADDED
@@ -0,0 +1,9 @@
rouge1,rouge2,rougeL,rougeLsum,gen_len,checkpoint
0.39233350039050524,0.1331263872944557,0.30561232240272806,0.305581876074012,25.568,dtruong46me/flant5-small
0.42773411047439297,0.16070313389865537,0.33964372087731554,0.33971528751465496,24.633333333333333,dtruong46me/flant5-base
0.4436612424628238,0.18215770435271772,0.3574836391515892,0.3575112795473217,25.358,dtruong46me/train-bart-base
0.44596490799011734,0.1791041702437794,0.36099829444161424,0.3612203644902555,18.72,dtruong46me/bart-base-instructds2
0.5335,0.2672,0.5084,0,0,human-annotated-summary
0.4728,0.2118,0.4483,0,0,bart-large-in-paper
0.5165,0.2981,0.4336,0.4337,23.187,dtruong46me/bart-base-qds
0.4061788843274445,0.1588224274185049,0.3175643149646888,0.3207910509892517,26.058,dtruong46me/flan-t5-s
run_evaluation.py
ADDED
@@ -0,0 +1,77 @@
import warnings
warnings.filterwarnings("ignore")

from datasets import load_dataset

import os, sys

import pandas as pd
import argparse

path = os.path.abspath(os.path.join(os.path.dirname(__file__)))
sys.path.insert(0, path)

from src.model.model import load_model
from src.evaluate.evaluation import evaluation_rouge
from transformers import GenerationConfig


def save_metrics_to_csv(results, resultpath, checkpoint):

    results["checkpoint"] = checkpoint

    # Convert results to DataFrame
    df = pd.DataFrame([results])

    if not os.path.isfile(resultpath):
        df.to_csv(resultpath, index=False)
    else:
        df.to_csv(resultpath, mode='a', header=False, index=False)


def main():
    parser = argparse.ArgumentParser(description="Evaluation metric")
    parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
    parser.add_argument("--checkpoint", type=str, default="google/flan-t5-base")
    parser.add_argument("--resultpath", type=str, default="results/rouge_score.csv")

    parser.add_argument("--min_new_tokens", type=int, default=10)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)

    args = parser.parse_args()

    print("=========================================")
    print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
    print("=========================================")

    datapath = args.datapath
    checkpoint = args.checkpoint

    generation_config = GenerationConfig(
        min_new_tokens=args.min_new_tokens,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k
    )

    data = load_dataset("binwang/InstructDS_datasets", "DialogSum", split="test")

    model = load_model(checkpoint)
    print(f"Loaded model from: {checkpoint}")

    results = evaluation_rouge(model, data, generation_config)

    print("--------------------------")
    for k, v in results.items():
        print(f"{k}: {v}")
    print("--------------------------")

    save_metrics_to_csv(results, args.resultpath, checkpoint)
    print(f"Results saved to: {args.resultpath}")

if __name__ == "__main__":
    main()
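A typical invocation of this script is `python run_evaluation.py --checkpoint dtruong46me/flant5-base --resultpath results/rouge_score.csv` (the checkpoint here is just an example); each run appends one row of ROUGE scores and the average generated length to the results CSV above.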
run_training.py
ADDED
@@ -0,0 +1,39 @@
import wandb
from huggingface_hub import login

import warnings
warnings.filterwarnings("ignore")

import os
import sys

path = os.path.abspath(os.path.join(os.path.dirname(__file__)))
sys.path.insert(0, path)

from src.pipelines.training_pipeline import training_pipeline
from src.utils import parse_args

def main():
    # Load argument parser
    args = parse_args()
    print(f"\033[92mLoaded argument parsers\033[00m")

    # Load token ID
    huggingface_hub_token = args.huggingface_hub_token
    wandb_token = args.wandb_token

    if wandb_token:
        os.environ["WANDB_PROJECT"] = "nlp_project"

    # Login to Huggingface Hub and WandB
    login(token=huggingface_hub_token)
    print("\033[92mSuccessful login to Huggingface Hub\033[00m")

    wandb.login(key=wandb_token)
    print("\033[92mSuccessful login to WandB\033[00m")

    training_pipeline(args)
    print("\033[92mFinish training pipeline\033[00m")

if __name__=='__main__':
    main()
setup.sh
ADDED
@@ -0,0 +1,8 @@
echo "Hello"
echo "..."
pip install -q --upgrade pip
pip install -q -U datasets
pip install -q transformers
pip install -q -r "/kaggle/working/dialogue-text-summarization/requirements.txt"
echo "---------"
echo "Set up complete!"
src/.DS_Store
ADDED
Binary file (6.15 kB)
src/data/create_dataset.py
ADDED
@@ -0,0 +1,115 @@
import sys, os

import argparse

from bert_score import BERTScorer

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    AutoTokenizer
)

import warnings
warnings.filterwarnings("ignore")

from huggingface_hub import login

from datasets import load_dataset, Dataset

path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)

from preprocessing import *

def create_qds_triplet(datapath, split, start_index, end_index) -> Dataset:
    data = load_dataset(datapath, split=split)
    data = Dataset.from_dict(data[start_index:end_index])

    scorer = BERTScorer(lang="en", rescale_with_baseline=True)

    CHECKPOINT = "google/flan-t5-large"
    tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)
    model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)

    qds_triplet = {
        "query": [],
        "dialogue": [],
        "summary": []
    }

    dsp = DialogSumDataset(
        tokenizer=AutoTokenizer.from_pretrained(CHECKPOINT)
    )

    for dialogue, summary in zip(data["dialogue"], data["summary"]):
        answerable_queries = []

        while len(answerable_queries) < 1:
            queries = dsp.generate_queries(model, tokenizer, summary, num_queries=5)

            for query in queries:
                ## Text based filtering
                output = dsp.text_based_filtering(model, tokenizer, query, summary)
                if "yes" in output.lower():
                    answerable_queries.append(query)

        n = len(answerable_queries)
        print("Length of answerable queries:", n, end=" ### ")

        if n == 1:
            qds_triplet["query"].append(answerable_queries[0])
            qds_triplet["dialogue"].append(dialogue)
            qds_triplet["summary"].append(summary)

        if n > 1:
            filtered_queries = []
            scores = [[0.0]*n for _ in range(n)]

            for i in range(n):
                for j in range(n):
                    if i > j:
                        scores[i][j] = dsp.semantic_filtering(scorer, answerable_queries[i], answerable_queries[j])

            keep_indices = set(range(n))
            for i in range(n):
                for j in range(n):
                    if scores[i][j] > 0.7 and i > j:
                        keep_indices.discard(j)

            for i in sorted(keep_indices):
                filtered_queries.append(answerable_queries[i])

            print("Length of filtered queries:", len(filtered_queries), end=" ### ")

            for query in filtered_queries:
                qds_triplet["query"].append(query)
                qds_triplet["dialogue"].append(dialogue)
                qds_triplet["summary"].append(summary)

    print("Length of inputs:", len(qds_triplet["summary"]))

    return Dataset.from_dict(qds_triplet)

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
    parser.add_argument("--huggingface_hub_token", type=str, default="")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1)
    args = parser.parse_args()

    print("=========================================")
    print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
    print("=========================================")

    login(token=args.huggingface_hub_token)
    print("Successfully logged in to Huggingface Hub")

    qds_triplet = create_qds_triplet(args.datapath, args.split, args.start_index, args.end_index)

    save_name = f"dialogsum-{args.split}-{args.start_index}-{args.end_index}"
    qds_triplet.push_to_hub(save_name)
    print(f"Saved to: {save_name}")
src/data/ingest_data.py
ADDED
@@ -0,0 +1,113 @@
from datasets import load_dataset
from datasets import DatasetDict, Dataset
import random
from transformers import set_seed


def ingest_data(datapath: str) -> DatasetDict:
    set_seed(42)

    QDS_LIMIT = 6000
    if "," in datapath:
        datapaths = datapath.split(",")

    datapath1 = "binwang/InstructDS_datasets"
    datapath2 = "binwang/InstructDS_datasets"

    all_train_data = []
    origin_train_dialogsum = load_dataset(datapath1, "DialogSum", split="train")
    qds_dialogsum = load_dataset(datapath2, "DialogSum_QDS", split="train")

    new_data1 = []
    for sample in origin_train_dialogsum:
        new_sample = {
            "instruction": "Please summarize the following dialogue.",
            "input": sample["dialogue"],
            "output": sample["summary"]
        }
        new_data1.append(new_sample)
    origin_train_dialogsum = new_data1
    all_train_data.extend(origin_train_dialogsum)

    print("Len of origin_train_dialogsum: ", len(origin_train_dialogsum))
    print("Len of all train data 1: ", len(all_train_data))

    new_data2 = []
    for sample in qds_dialogsum:
        new_sample = {
            "instruction": "Please answer the following question.",
            "input": sample["dialogue"],
            "output": sample["summary"]
        }
        new_data2.append(new_sample)
    qds_dialogsum = new_data2
    qds_dialogsum = random.sample(qds_dialogsum, QDS_LIMIT)
    all_train_data.extend(qds_dialogsum)
    print("Len of all train data 2: ", len(all_train_data))


    naive_all_train_data_dict = {
        "instruction": [item["instruction"] for item in all_train_data],
        "input": [item["input"] for item in all_train_data],
        "output": [item["output"] for item in all_train_data]
    }

    print("Len of naive_all_train_data_dict: ", len(naive_all_train_data_dict["instruction"]))

    subset_train_data = all_train_data
    with_len_train_data_dict = {
        "instruction": [item["instruction"] + f" The output should be {len(item['output'].split())} words long." for item in subset_train_data],
        "input": [item["input"] for item in subset_train_data],
        "output": [item["output"] for item in subset_train_data]
    }

    print("Len of with_len_train_data_dict: ", len(with_len_train_data_dict["instruction"]))

    all_train_data_dict = {
        "instruction": naive_all_train_data_dict["instruction"] + with_len_train_data_dict["instruction"],
        "input": naive_all_train_data_dict["input"] + with_len_train_data_dict["input"],
        "output": naive_all_train_data_dict["output"] + with_len_train_data_dict["output"]
    }

    print("Len of all_train_data_dict: ", len(all_train_data_dict["instruction"]))

    raw_train_data = Dataset.from_dict(all_train_data_dict)
    train_data = raw_train_data.shuffle()

    print(type(train_data))
    print(train_data["instruction"][:10])
    print(train_data["input"][:10])
    print(train_data["output"][:10])

    print("===================", len(train_data), "===================")

    # Validation data
    all_validation_data = []
    origin_validation_dialogsum = load_dataset(datapath1, "DialogSum", split="validation")

    new_data1 = []
    for sample in origin_validation_dialogsum:
        new_sample = {
            "instruction": "Please summarize the following dialogue.",
            "input": sample["dialogue"],
            "output": sample["summary"]
        }
        new_data1.append(new_sample)

    origin_validation_dialogsum = new_data1
    all_validation_data.extend(origin_validation_dialogsum)

    all_validation_data_dict = {
        "instruction": [item["instruction"] for item in all_validation_data],
        "input": [item["input"] for item in all_validation_data],
        "output": [item["output"] for item in all_validation_data]
    }

    raw_validation_data = Dataset.from_dict(all_validation_data_dict)
    validation_data = raw_validation_data.shuffle()

    return DatasetDict({
        "train": train_data,
        "validation": validation_data
    })
src/data/merge_dataset.py
ADDED
@@ -0,0 +1,41 @@
import os, sys

import argparse

from datasets import load_dataset, concatenate_datasets, Dataset
from huggingface_hub import login

path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)

def merge_dataset(datapaths) -> Dataset:
    datapaths = datapaths.split(",")
    dataset = load_dataset(datapaths[0], split="train")

    for i in range(1, len(datapaths)):
        data = load_dataset(datapaths[i], split="train")
        # Accumulate each additional split into the running dataset
        dataset = concatenate_datasets([dataset, data])

    return dataset


if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapaths", type=str, default="")
    parser.add_argument("--huggingface_hub_token", type=str, default="")
    parser.add_argument("--split", type=str, default="train")
    args = parser.parse_args()

    print("=========================================")
    print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
    print("=========================================")

    login(token=args.huggingface_hub_token)
    print("Successfully logged in to Huggingface Hub")

    dataset = merge_dataset(datapaths=args.datapaths)

    DATASET_ID = "qds-triplet-dialogsum"
    dataset.push_to_hub(DATASET_ID)
    print(f"Successful push to Huggingface Hub: {DATASET_ID}")
src/data/preprocessing.py
ADDED
@@ -0,0 +1,113 @@
from datasets import DatasetDict, Dataset
import random
from bert_score import BERTScorer

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration
)

class DialogSumDataset:
    def __init__(self, tokenizer, use_contrastive_loss=False, tokenizing_strategy=1) -> None:
        self.tokenizer = tokenizer
        self.use_contrastive_loss = use_contrastive_loss
        self.tokenizing_strategy = tokenizing_strategy

    def handle_data(self, data: DatasetDict) -> DatasetDict:
        try:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            tokenized_dataset = data.map(self.preprocess_function, batched=True)
            tokenized_dataset = tokenized_dataset.remove_columns([key for key in data["train"][0].keys()])

            print("+++++++++++++++++++")
            print(tokenized_dataset)
            print("+++++++++++++++++++")

            return tokenized_dataset

        except Exception as e:
            print(f"\033[31m\nError while tokenizing data: {e}\033[00m")
            raise e

    def preprocess_function(self, data: Dataset) -> Dataset:
        if self.tokenizing_strategy <= 2:
            prefix = "Summarize the following conversation:\n###\n"
            suffix = "\n###\nSummary: "
            inputs = [prefix + input + suffix for input in data["dialogue"]]
            targets = data["summary"]

        if self.tokenizing_strategy == 1:
            max_source_length = 1024
            max_target_length = 176

        if self.tokenizing_strategy == 2:
            max_source_length = 1224
            max_target_length = 176

        if self.tokenizing_strategy == 3:
            inputs = ["### Instruction: " + instruction + "\n### Input: " + input + "\n### Response: " for instruction, input in zip(data["instruction"], data["input"])]
            targets = data["output"]

            max_source_length = 1024
            max_target_length = 176

        data["input_ids"] = self.tokenizer(inputs, max_length=max_source_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
        # data["attention_mask"] = self.tokenizer(inputs, max_length=max_source_length, padding="max_length", truncation=True, return_tensors="pt").attention_mask
        data["labels"] = self.tokenizer(targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt").input_ids

        # Generate negative examples:
        if self.use_contrastive_loss == True:
            negative_summaries = self.generate_negative_examples(data["summary"])
            data["negative_labels"] = self.tokenizer(negative_summaries, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
            print("Complete generating negative examples!")

        # Replace padding token ids in the labels with -100 so they are ignored by the loss
        label_ignore_ids = []
        for label in data["labels"]:
            label_example = [l if l != 0 else -100 for l in label]
            label_ignore_ids.append(label_example)

        data["labels"] = label_ignore_ids

        return data

    ## Create Negative Examples for Contrastive Learning
    def generate_negative_examples(self, summaries):
        negative_summaries = []
        for summary in summaries:
            words = summary.split()
            random.shuffle(words)
            negative_summaries.append(" ".join(words))
        return negative_summaries

    ## Create Instruction Dataset
    def generate_queries(self, model, tokenizer, summary, num_queries):
        input_text = "Generate an answerable and specific question based on the following context:. ###\nContext: " + summary
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(input_ids, max_length=64, num_return_sequences=num_queries, do_sample=True)
        queries = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        return queries

    def text_based_filtering(self, model, tokenizer, query, summary):
        input_text = "Is the question fully answerable from the context without any guessing, yes or no?###\nQuestion: " + query + "###\nContext: " + summary + "###Answer: "
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        output_ids = model.generate(input_ids, num_return_sequences=1)
        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return output_text

    def semantic_filtering(self, scorer, query1, query2):
        score = scorer.score([query1], [query2])[0]
        return score


def preprocessing_data(data: DatasetDict, tokenizer, use_contrastive_loss=False, tokenizing_strategy=1) -> DatasetDict:
    try:
        dataset_ds = DialogSumDataset(tokenizer, use_contrastive_loss, tokenizing_strategy)
        tokenized_data = dataset_ds.handle_data(data)

        return tokenized_data

    except Exception as e:
        print(f"\nError while pre-processing data: {e}")
        raise e
src/evaluate/evaluation.py
ADDED
@@ -0,0 +1,81 @@
import os
import sys

from datasets import Dataset

import evaluate
import torch

import logging

# = = = = = = = = = = = Logging Setup = = = = = = = = = = = = =
logger = logging.getLogger(__name__)
logging.basicConfig(
    format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt = "%m/%d/%Y %H:%M:%S",
    level = logging.INFO,
)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = =

from transformers import AutoModelForSeq2SeqLM

path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, path)

from model.model import Model


class RougeEvaluation:
    def __init__(self) -> None:
        self.rouge_metric = evaluate.load("rouge")

    def compute_rouge_metric(self, generated_summary, reference_summary) -> dict:
        results = self.rouge_metric.compute(
            predictions=generated_summary,
            references=reference_summary,
            use_aggregator=True,
            use_stemmer=True
        )
        return results


def evaluation_rouge(model: Model, data: Dataset, generation_config) -> dict:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.base_model = model.get_model()

    dialogues = data["dialogue"]

    human_summaries = [summary for summary in data["summary"]]

    model_summaries = []

    prefix = "Summarize the following dialogue:\n###\n"
    suffix = "\n### Summary: "

    # print("\n******************************")
    # idx = 0
    # for answer, dialogue in zip(data["answer"], data["dialogue"]):
    #     prefix = "Please summarize the following dialogue focused on the context query:"
    #     input = prefix + "\n### Query: " + answer + "\n### Dialogue: " + dialogue + "\n### The summary should be around " + str(int(0.2*len(dialogue.split()))) + " words." + "\n### Summary: "

    for idx, dialogue in enumerate(dialogues):
        input = prefix + dialogue + suffix

        print(idx, end="# ")
        output_text = model.generate_summary(input, generation_config, do_sample=False)

        model_summaries.append(output_text)

    logger.info("Evaluating summaries...")

    rouge_evaluator = RougeEvaluation()

    results = rouge_evaluator.compute_rouge_metric(model_summaries, human_summaries)

    generated_lengths = [len(summary.split()) for summary in model_summaries]
    average_gen_len = sum(generated_lengths) / len(generated_lengths) if generated_lengths else 0

    results["gen_len"] = average_gen_len

    return results
src/evaluate/rouge_metric.py
ADDED
@@ -0,0 +1,52 @@
import evaluate
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize

from transformers import AutoTokenizer

import os
import sys

path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, path)


def postprocess_text(preds, labels):
    nltk.download("punkt")

    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds, tokenizer, metric):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # metric = evaluate.load("rouge")
    rouge_results = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    rouge_results = {k: round(v * 100, 4) for k, v in rouge_results.items()}

    results = {
        "rouge1": rouge_results["rouge1"],
        "rouge2": rouge_results["rouge2"],
        "rougeL": rouge_results["rougeL"],
        "rougeLsum": rouge_results["rougeLsum"],
        "gen_len": np.mean([np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds])
    }

    return results
src/model/model.py
ADDED
@@ -0,0 +1,80 @@
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

from peft import (
    get_peft_model,
)

class Model:
    def __init__(self, checkpoint):
        self.checkpoint = checkpoint
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.base_model = None

    def get_model(self):
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)

    def get_peft(self, lora_config):
        return get_peft_model(self.base_model, lora_config)

    def prepare_quantize(self, bnb_config):
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint,
                                                     quantization_config=bnb_config,
                                                     device_map={"":0},
                                                     trust_remote_code=True)
        # self.base_model.gradient_checkpointing_enable()
        # self.base_model = prepare_model_for_kbit_training(self.base_model)


    def generate_summary(self, input_text, generation_config, do_sample=True):
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=1024, truncation=True, padding="max_length")
        output_ids = self.base_model.generate(input_ids=input_ids, do_sample=do_sample, generation_config=generation_config)

        if "bart" in self.checkpoint:
            output_ids[0][1] = 2

        output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\033[94mSummary: {output_text}\n\033[00m")
        return output_text

class BartSum(Model):
    def __init__(self, checkpoint):
        super().__init__(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

    def get_model(self):
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)


class FlanT5Sum(Model):
    def __init__(self, checkpoint):
        super().__init__(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)

    def get_model(self):
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)


def load_model(checkpoint):

    try:
        if "bart" in checkpoint:
            print(f"\033[92mLoad Bart model from checkpoint: {checkpoint}\033[00m")
            return BartSum(checkpoint)

        elif "flan" in checkpoint:
            print(f"\033[92mLoad Flan-T5 model from checkpoint: {checkpoint}\033[00m")
            return FlanT5Sum(checkpoint)

        else:
            print(f"\033[92mLoad general model from checkpoint: {checkpoint}\033[00m")
            return Model(checkpoint)

    except Exception as e:
        print(f"Error while loading model: {e}")
        raise e
src/pipelines/deploy_pipeline.py
ADDED
File without changes
src/pipelines/training_pipeline.py
ADDED
@@ -0,0 +1,168 @@
import os
import sys
import argparse
import numpy as np
import nltk

from nltk.tokenize import sent_tokenize
from transformers import (
    Seq2SeqTrainer,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)

from peft import get_peft_model, prepare_model_for_kbit_training

path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, path)

from utils import *

# from model.models import load_model
from model.model import load_model
from data.preprocessing import preprocessing_data
from data.ingest_data import ingest_data

import evaluate


def training_pipeline(args: argparse.Namespace):
    try:
        print("=========================================")
        print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
        print("=========================================")

        import torch
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = load_model(args.checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
        print(type(tokenizer))

        if (args.lora == False):
            print("lora=False, quantize=False")
            base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint).to(device)
            # model.base_model = model.get_model()
            # model.base_model.to(device)

        else:
            from peft import LoraConfig, TaskType
            from transformers import BitsAndBytesConfig
            import torch
            # Define LoRA Config
            lora_config = LoraConfig(
                r=args.lora_rank,
                lora_alpha=args.lora_alpha,
                target_modules=args.target_modules.split(","),
                lora_dropout=args.lora_dropout,
                bias="none",
                task_type=TaskType.SEQ_2_SEQ_LM
            )

            if (args.quantize == True):
                print("Quantize=True, lora=True")
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16
                )

                base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint,
                                                                   quantization_config=bnb_config,
                                                                   device_map={"":0},
                                                                   trust_remote_code=True)
                base_model = prepare_model_for_kbit_training(base_model)

            if (args.quantize == False):
                print("Quantize=False, lora=True")
                base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint).to(device)

            # add LoRA adaptor
            print("Base model:", model.base_model)
            base_model = get_peft_model(base_model, lora_config)
            base_model.print_trainable_parameters()


        # Load data from datapath
        data = ingest_data(args.datapath)
        print("\033[92m[+] Complete loading dataset!\033[00m")

        # Pre-processing data
        data = preprocessing_data(data, tokenizer, use_contrastive_loss=args.use_contrastive_loss, tokenizing_strategy=args.tokenizing_strategy)
        print("\033[92m[+] Complete pre-processing dataset!\033[00m")

        # Load training arguments
        training_args = load_training_arguments(args)
        print("\033[92m[+] Complete loading training arguments!\033[00m")

        # Load metric
        metric = evaluate.load("rouge")
        nltk.download("punkt")

        def postprocess_text(preds, labels):
            preds = [pred.strip() for pred in preds]
            labels = [label.strip() for label in labels]

            preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
            labels = ["\n".join(sent_tokenize(label)) for label in labels]

            return preds, labels

        def compute_metric(eval_preds):
            preds, labels = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]

            decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # Some simple post-processing
            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

            # metric = evaluate.load("rouge")
            rouge_results = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
            rouge_results = {k: round(v * 100, 4) for k, v in rouge_results.items()}

            results = {
                "rouge1": rouge_results["rouge1"],
                "rouge2": rouge_results["rouge2"],
                "rougeL": rouge_results["rougeL"],
                "rougeLsum": rouge_results["rougeLsum"],
                "gen_len": np.mean([np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds])
            }

            return results

        # Load trainer
        if args.use_contrastive_loss == True:
            trainer = ContrastiveLearningTrainer(model=base_model,
                                                 train_dataset=data["train"],
                                                 eval_dataset=data["validation"],
                                                 tokenizer=tokenizer,
                                                 compute_metrics=compute_metric)

        if args.use_contrastive_loss == False:
            trainer = Seq2SeqTrainer(model=base_model,
                                     args=training_args,
                                     train_dataset=data["train"],
                                     eval_dataset=data["validation"],
                                     tokenizer=tokenizer,
                                     compute_metrics=compute_metric)

        print("\033[92m[+] Complete loading trainer!\033[00m")

        # Train model
        trainer.train()
        print("\033[92m[+] Complete training!\033[00m")

        # Push to Huggingface Hub
        trainer.push_to_hub()
        print("\033[92m [+] Complete pushing model to hub!\033[00m")

    except Exception as e:
        print(f"\033[31m\nError while training: {e}\033[00m")
        raise e
src/test/test_rouge.py
ADDED
File without changes
src/utils.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import argparse
+
+import os
+import sys
+
+import torch
+import torch.nn as nn
+
+from transformers import (
+    Seq2SeqTrainingArguments,
+    Seq2SeqTrainer,
+)
+
+path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, path)
+
+# from src.evaluate.rouge_metric import compute_metrics
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Fine tuning LLM for Dialogue Text Summarization")
+    parser.add_argument("--huggingface_hub_token", type=str, default=None)
+    parser.add_argument("--wandb_token", type=str, default=None)
+
+    parser.add_argument("--checkpoint", type=str, default="google/flan-t5-base")
+    parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
+
+    parser.add_argument("--output_dir", type=str, default="fine-tuned-flant5")
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+
+    parser.add_argument("--num_train_epochs", type=int, default=3)
+    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
+    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
+
+    parser.add_argument("--learning_rate", type=float, default=0.00005)
+    parser.add_argument("--weight_decay", type=float, default=0.005)
+
+    parser.add_argument("--evaluation_strategy", type=str, default="no")
+    parser.add_argument("--save_strategy", type=str, default="no")
+
+    parser.add_argument("--logging_strategy", type=str, default="steps")
+    parser.add_argument("--logging_steps", type=int, default=1000)
+    parser.add_argument("--save_total_limit", type=int, default=1)
+
+    parser.add_argument("--report_to", type=str, default="wandb")
+    parser.add_argument("--run_name", type=str, default="flan-t5-base-model")
+
+    parser.add_argument("--predict_with_generate", action="store_true")
+
+    parser.add_argument("--min_new_tokens", type=int, default=10)
+    parser.add_argument("--max_new_tokens", type=int, default=256)
+    parser.add_argument("--temperature", type=float, default=0.9)
+    parser.add_argument("--top_p", type=float, default=1.0)
+    parser.add_argument("--top_k", type=int, default=50)
+
+    parser.add_argument("--lora", action="store_true")
+    parser.add_argument("--quantize", action="store_true")
+
+    parser.add_argument("--lora_rank", type=int, default=8)
+    parser.add_argument("--lora_alpha", type=int, default=16)
+    parser.add_argument("--target_modules", type=str, default="q,v")
+    parser.add_argument("--lora_dropout", type=float, default=0.05)
+
+    parser.add_argument("--use_contrastive_loss", action="store_true")
+    parser.add_argument("--tokenizing_strategy", type=int, default=1)
+
+    args = parser.parse_args()
+    return args
+
+
+def load_training_arguments(args):
+    try:
+        training_args = Seq2SeqTrainingArguments(
+            output_dir=args.output_dir,
+            overwrite_output_dir=args.overwrite_output_dir,
+
+            num_train_epochs=args.num_train_epochs,
+            per_device_train_batch_size=args.per_device_train_batch_size,
+            per_device_eval_batch_size=args.per_device_eval_batch_size,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+
+            learning_rate=args.learning_rate,
+            weight_decay=args.weight_decay,
+
+            evaluation_strategy=args.evaluation_strategy,
+            save_strategy=args.save_strategy,
+
+            logging_strategy=args.logging_strategy,
+            logging_steps=args.logging_steps,
+            save_total_limit=args.save_total_limit,
+
+            report_to=args.report_to,
+            run_name=args.run_name,
+
+            predict_with_generate=args.predict_with_generate
+        )
+
+        return training_args
+
+    except Exception as e:
+        print(f"Error while loading training arguments: {e}")
+        raise e
+
+class ContrastiveLoss(nn.Module):
+    def __init__(self, margin=1.0):
+        super(ContrastiveLoss, self).__init__()
+        self.margin = margin
+        self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)
+
+    def forward(self, dialogue_embeddings, pos_summary_embeddings, neg_summary_embeddings):
+        # Pull dialogue and positive-summary embeddings together, push negatives beyond the margin
+        pos_sim = self.cosine_similarity(dialogue_embeddings, pos_summary_embeddings)
+        neg_sim = self.cosine_similarity(dialogue_embeddings, neg_summary_embeddings)
+        loss = torch.mean(1 - pos_sim) + torch.mean(torch.clamp(neg_sim - self.margin, min=0.0))
+
+        return loss
+
+class ContrastiveLearningTrainer(Seq2SeqTrainer):
+    def compute_loss(self, model, inputs, return_outputs=False):
+        output = model(**inputs)
+        lm_loss = output.loss
+
+        dialogue_embeddings = model.encoder(inputs["input_ids"]).last_hidden_state
+        pos_summary_embeddings = model.encoder(inputs["labels"]).last_hidden_state
+        neg_summary_embeddings = model.encoder(inputs["negative_labels"]).last_hidden_state
+
+        contrastive_loss = ContrastiveLoss(margin=1.0)(dialogue_embeddings, pos_summary_embeddings, neg_summary_embeddings)
+
+        # Combine the language-modeling loss with the contrastive loss
+        total_loss = lm_loss + contrastive_loss
+
+        return (total_loss, output) if return_outputs else total_loss
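As a quick sanity check, ContrastiveLoss can be exercised on dummy tensors. The snippet below is an illustrative sketch, not part of the uploaded files; the batch size, hidden size, and the use of pooled 2-D embeddings (rather than full encoder hidden states) are assumptions.

# Illustrative check of ContrastiveLoss on dummy, mean-pooled embeddings (not repository code).
import torch

batch_size, hidden_size = 4, 768
dialogue_emb = torch.randn(batch_size, hidden_size)
pos_summary_emb = dialogue_emb + 0.1 * torch.randn(batch_size, hidden_size)  # close to the dialogue
neg_summary_emb = torch.randn(batch_size, hidden_size)                       # unrelated summary

loss_fn = ContrastiveLoss(margin=1.0)
loss = loss_fn(dialogue_emb, pos_summary_emb, neg_summary_emb)
print(loss.item())  # scalar; smaller when dialogue and positive-summary embeddings align

With 2-D inputs, CosineSimilarity(dim=1) compares along the feature dimension and returns one similarity per example, so the loss reduces cleanly to a scalar that can be added to the language-modeling loss.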
test_streaming.py
ADDED
@@ -0,0 +1,76 @@
+import streamlit as st
+import replicate
+import os
+from transformers import AutoTokenizer, GenerationConfig, AutoModelForSeq2SeqLM
+import torch
+
+# Set Replicate API token
+with st.sidebar:
+    st.title('Dialogue Text Summarization')
+    if 'REPLICATE_API_TOKEN' in st.secrets:
+        replicate_api = st.secrets['REPLICATE_API_TOKEN']
+    else:
+        replicate_api = st.text_input('Enter Replicate API token:', type='password')
+        if not (replicate_api.startswith('r8_') and len(replicate_api) == 40):
+            st.warning('Please enter your Replicate API token.', icon='⚠️')
+            st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
+
+    os.environ['REPLICATE_API_TOKEN'] = replicate_api
+    st.subheader("Adjust model parameters")
+    min_new_tokens = st.slider('Min new tokens', min_value=1, max_value=256, step=1, value=10)
+    temperature = st.slider('Temperature', min_value=0.01, max_value=1.00, step=0.01, value=1.0)
+    top_k = st.slider('Top_k', min_value=1, max_value=50, step=1, value=20)
+    top_p = st.slider('Top_p', min_value=0.01, max_value=1.00, step=0.01, value=1.0)
+
+# Initialize model and tokenizer
+checkpoint = "dtruong46me/train-bart-base"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
+
+st.title("Dialogue Text Summarization")
+st.caption("Natural Language Processing Project 20232")
+st.write("---")
+
+input_text = st.text_area("Dialogue", height=200)
+
+generation_config = GenerationConfig(
+    min_new_tokens=min_new_tokens,
+    max_new_tokens=320,
+    temperature=temperature,
+    top_p=top_p,
+    top_k=top_k
+)
+
+def generate_summary(model, input_text, generation_config, tokenizer):
+    # Build the summarization prompt and return it as a plain string for the remote model
+    prefix = "Summarize the following conversation: \n\n###"
+    suffix = "\n\nSummary:"
+    input_ids = tokenizer.encode(prefix + input_text + suffix, return_tensors="pt").to(model.device)
+    prompt_str = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+    return prompt_str
+
+def stream_summary(prompt_str, temperature, top_p):
+    # Stream summary tokens from the Replicate-hosted model
+    for event in replicate.stream(
+        "snowflake/snowflake-arctic-instruct",
+        input={"prompt": prompt_str,
+               "prompt_template": r"{prompt}",
+               "temperature": temperature,
+               "top_p": top_p}):
+        yield str(event['output'])
+
+if st.button("Submit"):
+    st.write("---")
+    st.write("## Summary")
+
+    if not replicate_api:
+        st.error("Please enter your Replicate API token!")
+    elif not input_text:
+        st.error("Please enter a dialogue!")
+    else:
+        prompt_str = generate_summary(model, input_text, generation_config, tokenizer)
+        summary_container = st.empty()
+
+        summary_text = ""
+        for output in stream_summary(prompt_str, temperature, top_p):
+            summary_text += output
+            summary_container.text(summary_text)
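test_streaming.py loads the fine-tuned checkpoint and a GenerationConfig but delegates the actual summary generation to the Replicate stream. For comparison, the sketch below shows how the locally loaded model could produce the summary instead; it is illustrative only, not part of the uploaded files, and the example dialogue is made up.

# Minimal local-generation sketch with the fine-tuned checkpoint (illustrative only).
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig

checkpoint = "dtruong46me/train-bart-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

dialogue = "A: Are we still meeting at 3pm?\nB: Yes, see you in room 201."
prompt = "Summarize the following conversation: \n\n###" + dialogue + "\n\nSummary:"

input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
generation_config = GenerationConfig(min_new_tokens=10, max_new_tokens=320)
output_ids = model.generate(input_ids, generation_config=generation_config)
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(summary)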