minh21 committed
Commit 1f4a7d9
1 Parent(s): 45c79cf

Create app.py

Files changed (1)
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
+ import gradio as gr
+ from gradio.components import Textbox, Checkbox
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
+ from peft import PeftModel, PeftConfig
+ import torch
+ import datasets
+
+ # Load the base model and tokenizer, then attach the fine-tuned IA3 QA adapter
+ model_name = "google/flan-t5-large"
+ peft_name = "legacy107/flan-t5-large-ia3-cpgQA"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = PeftModel.from_pretrained(model, peft_name)
+
+ # Attach the paraphrasing adapter to a second copy of the base model
+ peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
+ peft_config = PeftConfig.from_pretrained(peft_name)
+ paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)
+
+ max_length = 512
+ max_target_length = 200
+
+ # Load the test split and sample 10 examples for the demo
+ dataset = datasets.load_dataset("minh21/cpgQA-v1.0-unique-context-test-10-percent-validation-10-percent", split="test")
+ dataset = dataset.shuffle()
+ dataset = dataset.select(range(10))
+
+
+ def paraphrase_answer(question, answer, use_pretrained=False):
+     # Combine the question and the extracted answer into a paraphrasing prompt
+     input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the paraphrased answer
+     with torch.no_grad():
+         if use_pretrained:
+             generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+         else:
+             generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+
+     # Decode and return the generated answer
+     paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     return paraphrased_answer
+
+
+ # Define the answer-generation function; ground_truth is displayed in the UI only and unused here
+ def generate_answer(question, context, ground_truth, do_pretrained, do_natural, do_pretrained_natural):
+     # Combine question and context
+     input_text = f"question: {question} context: {context}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the answer with the fine-tuned model
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+
+     # Decode the generated answer
+     generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     # Optionally paraphrase the answer with the fine-tuned paraphrasing adapter
+     paraphrased_answer = ""
+     if do_natural:
+         paraphrased_answer = paraphrase_answer(question, generated_answer)
+
+     # Optionally get the pretrained (base) model's answer
+     pretrained_answer = ""
+     if do_pretrained:
+         with torch.no_grad():
+             pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+         pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
+
+     # Optionally paraphrase the fine-tuned answer with the pretrained (base) model
+     pretrained_paraphrased_answer = ""
+     if do_pretrained_natural:
+         pretrained_paraphrased_answer = paraphrase_answer(question, generated_answer, True)
+
+     return generated_answer, paraphrased_answer, pretrained_answer, pretrained_paraphrased_answer
+
+
+ # List examples from the dataset to prefill the interface
+ def list_examples():
+     examples = []
+     for example in dataset:
+         context = example["context"]
+         question = example["question"]
+         answer = example["answer_text"]
+         examples.append([question, context, answer, True, True, True])
+     return examples
+
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=[
+         Textbox(label="Question"),
+         Textbox(label="Context"),
+         Textbox(label="Ground truth"),
+         Checkbox(label="Include pretrained model's answer"),
+         Checkbox(label="Include natural answer"),
+         Checkbox(label="Include pretrained model's natural answer")
+     ],
+     outputs=[
+         Textbox(label="Generated Answer"),
+         Textbox(label="Natural Answer"),
+         Textbox(label="Pretrained Model's Answer"),
+         Textbox(label="Pretrained Model's Natural Answer")
+     ],
+     examples=list_examples()
+ )
+
+ # Launch the Gradio interface
+ iface.launch()
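For a quick smoke test of the core logic outside the Gradio UI, generate_answer can be called directly once the models have loaded. This is a minimal sketch; the question and context strings below are hypothetical placeholders for illustration, not drawn from the cpgQA dataset:

    # Hypothetical inputs for illustration only
    answer, natural_answer, pretrained_answer, pretrained_natural = generate_answer(
        question="What is the recommended first-line treatment?",
        context="The guideline recommends drug X as the first-line treatment for condition Y.",
        ground_truth="",             # displayed in the UI only; unused by the function
        do_pretrained=True,          # also answer with the base flan-t5-large
        do_natural=True,             # paraphrase the fine-tuned answer
        do_pretrained_natural=False,
    )
    print(answer, natural_answer, pretrained_answer, sep="\n")

Note that iface.launch() blocks, so a check like this would run before the launch call or in a separate session.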