Towhidul commited on
Commit
96d3362
β€’
1 Parent(s): 4ab6a77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -0
app.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ import re
4
+
5
+ import requests
6
+
7
+ API_URL = "https://api-inference.huggingface.co/models/microsoft/prophetnet-large-uncased-squad-qg"
8
+ headers = {"Authorization": "Bearer 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"}
9
+
10
+ def query(payload):
11
+ response = requests.post(API_URL, headers=headers, json=payload)
12
+ return response.json()
13
+
14
+
15
+ #-----------------------------------------------------------
16
+
17
+ API_URL_evidence ="https://api-inference.huggingface.co/models/google/flan-t5-xxl"
18
+ headers_evidence = {"Authorization": "Bearer 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"}
19
+
20
+ def query_evidence(payload):
21
+ response = requests.post(API_URL_evidence, headers=headers_evidence, json=payload)
22
+ return response.json()
23
+
24
+ #-----------------------------------------------------------
25
+ claim_text=st.text_area("Enter your claim:")
26
+
27
+ evidence_text=st.text_area("Enter your evidence:")
28
+
29
+ import pandas as pd
30
+ import numpy as np
31
+ from allennlp.predictors.predictor import Predictor
32
+ import allennlp_models.tagging
33
+ predictor = Predictor.from_path("/kaggle/input/vitc-sampled-evidence/structured-prediction-srl-bert")
34
+
35
+ #---------------------------------------------------------------
36
+ def claim(text):
37
+ df = pd.DataFrame({'claim' : [text]})
38
+ def srl_allennlp(sent):
39
+ try:
40
+ #result = predictor.predict(sentence=sent)['verbs'][0]['description']
41
+ #result = predictor.predict(sentence=sent)['verbs'][0]['tags']
42
+ result = predictor.predict(sentence=sent)
43
+ return(result)
44
+ except IndexError:
45
+ pass
46
+ #return(predictor.predict(sentence=sent))
47
+
48
+ df['allennlp_srl'] = df['claim'].apply(lambda x: srl_allennlp(x))
49
+
50
+ df['number_of_verbs'] = ''
51
+ df['verbs_group'] = ''
52
+ df['words'] = ''
53
+ df['verbs'] = ''
54
+ df['modified'] =''
55
+
56
+ col1 = df['allennlp_srl']
57
+ for i in range(len(col1)):
58
+ num_verb = len(col1[i]['verbs'])
59
+ df['number_of_verbs'][i] = num_verb
60
+ df['verbs_group'][i] = col1[i]['verbs']
61
+ df['words'][i] = col1[i]['words']
62
+
63
+ x=[]
64
+ for verb in range(len(col1[i]['verbs'])):
65
+ x.append(col1[i]['verbs'][verb]['verb'])
66
+ df['verbs'][i] = x
67
+
68
+ verb_dict ={}
69
+ desc = []
70
+ for j in range(len(col1[i]['verbs'])):
71
+ string = (col1[i]['verbs'][j]['description'])
72
+ string = string.replace("ARG0", "who")
73
+ string = string.replace("ARG1", "what")
74
+ string = string.replace("ARGM-TMP", "when")
75
+ string = string.replace("ARGM-LOC", "where")
76
+ string = string.replace("ARGM-CAU", "why")
77
+ desc.append(string)
78
+ verb_dict[col1[i]['verbs'][j]['verb']]=string
79
+ df['modified'][i] = verb_dict
80
+
81
+
82
+ #----------FOR COLUMN "WHO"------------#
83
+ df['who'] = ''
84
+ for j in range(len(df['modified'])):
85
+ val_list = []
86
+ val_string = ''
87
+ for k,v in df['modified'][j].items():
88
+ # print(type(v))
89
+ val_list.append(v)
90
+
91
+ who = []
92
+ for indx in range(len(val_list)):
93
+ val_string = val_list[indx]
94
+ pos = val_string.find("who: ")
95
+ substr = ''
96
+
97
+ if pos != -1:
98
+ for i in range(pos+5, len(val_string)):
99
+ if val_string[i] == "]":
100
+ break
101
+ else:
102
+ substr = substr + val_string[i]
103
+ else:
104
+ substr = None
105
+ who.append(substr)
106
+
107
+ df['who'][j] = who
108
+
109
+ #----------FOR COLUMN "WHAT"------------#
110
+ df['what'] = ''
111
+ for j in range(len(df['modified'])):
112
+ val_list = []
113
+ val_string = ''
114
+ for k,v in df['modified'][j].items():
115
+ # print(type(v))
116
+ val_list.append(v)
117
+
118
+ what = []
119
+ for indx in range(len(val_list)):
120
+ val_string = val_list[indx]
121
+ pos = val_string.find("what: ")
122
+ substr = ''
123
+
124
+ if pos != -1:
125
+ for i in range(pos+6, len(val_string)):
126
+ if val_string[i] == "]":
127
+ break
128
+ else:
129
+ substr = substr + val_string[i]
130
+ else:
131
+ substr = None
132
+ what.append(substr)
133
+
134
+ df['what'][j] = what
135
+
136
+ #----------FOR COLUMN "WHY"------------#
137
+ df['why'] = ''
138
+ for j in range(len(df['modified'])):
139
+ val_list = []
140
+ val_string = ''
141
+ for k,v in df['modified'][j].items():
142
+ # print(type(v))
143
+ val_list.append(v)
144
+
145
+ why = []
146
+ for indx in range(len(val_list)):
147
+ val_string = val_list[indx]
148
+ pos = val_string.find("why: ")
149
+ substr = ''
150
+
151
+ if pos != -1:
152
+ for i in range(pos+5, len(val_string)):
153
+ if val_string[i] == "]":
154
+ break
155
+ else:
156
+ substr = substr + val_string[i]
157
+ else:
158
+ substr = None
159
+ why.append(substr)
160
+
161
+ df['why'][j] = why
162
+
163
+ #----------FOR COLUMN "WHEN"------------#
164
+ df['when'] = ''
165
+ for j in range(len(df['modified'])):
166
+ val_list = []
167
+ val_string = ''
168
+ for k,v in df['modified'][j].items():
169
+ # print(type(v))
170
+ val_list.append(v)
171
+
172
+ when = []
173
+ for indx in range(len(val_list)):
174
+ val_string = val_list[indx]
175
+ pos = val_string.find("when: ")
176
+ substr = ''
177
+
178
+ if pos != -1:
179
+ for i in range(pos+6, len(val_string)):
180
+ if val_string[i] == "]":
181
+ break
182
+ else:
183
+ substr = substr + val_string[i]
184
+ else:
185
+ substr = None
186
+ when.append(substr)
187
+
188
+ df['when'][j] = when
189
+
190
+
191
+ #----------FOR COLUMN "WHERE"------------#
192
+ df['where'] = ''
193
+ for j in range(len(df['modified'])):
194
+ val_list = []
195
+ val_string = ''
196
+ for k,v in df['modified'][j].items():
197
+ # print(type(v))
198
+ val_list.append(v)
199
+
200
+ where = []
201
+ for indx in range(len(val_list)):
202
+ val_string = val_list[indx]
203
+ pos = val_string.find("where: ")
204
+ substr = ''
205
+
206
+ if pos != -1:
207
+ for i in range(pos+7, len(val_string)):
208
+ if val_string[i] == "]":
209
+ break
210
+ else:
211
+ substr = substr + val_string[i]
212
+ else:
213
+ substr = None
214
+ where.append(substr)
215
+
216
+ df['where'][j] = where
217
+
218
+ data=df[["claim","who","what","why","when","where"]].copy()
219
+ import re
220
+ def remove_trail_comma(text):
221
+ x = re.sub(",\s*$", "", text)
222
+ return x
223
+
224
+
225
+ data['claim']=data['claim'].apply(lambda x: str(x).replace('\'','').replace('\'',''))
226
+ data['claim']=data['claim'].apply(lambda x: str(x).replace('[','').replace(']',''))
227
+
228
+
229
+
230
+ data['who']=data['who'].apply(lambda x: str(x).replace(" 's","'s"))
231
+ data['who']=data['who'].apply(lambda x: str(x).replace("s ’","s’"))
232
+ data['who']=data['who'].apply(lambda x: str(x).replace(" - ","-"))
233
+ data['who']=data['who'].apply(lambda x: str(x).replace('\'','').replace('\'',''))
234
+ # data['who']=data['who'].apply(lambda x: str(x).replace('"','').replace('"',''))
235
+ data['who']=data['who'].apply(lambda x: str(x).replace('[','').replace(']',''))
236
+ data['who']=data['who'].apply(lambda x: str(x).rstrip(','))
237
+ data['who']=data['who'].apply(lambda x: str(x).lstrip(','))
238
+ data['who']=data['who'].apply(lambda x: str(x).replace('None,','').replace('None',''))
239
+ data['who']=data['who'].apply(remove_trail_comma)
240
+
241
+
242
+
243
+ data['what']=data['what'].apply(lambda x: str(x).replace(" 's","'s"))
244
+ data['what']=data['what'].apply(lambda x: str(x).replace("s ’","s’"))
245
+ data['what']=data['what'].apply(lambda x: str(x).replace(" - ","-"))
246
+ data['what']=data['what'].apply(lambda x: str(x).replace('\'','').replace('\'',''))
247
+ # data['what']=data['what'].apply(lambda x: str(x).replace('"','').replace('"',''))
248
+ data['what']=data['what'].apply(lambda x: str(x).replace('[','').replace(']',''))
249
+ data['what']=data['what'].apply(lambda x: str(x).rstrip(','))
250
+ data['what']=data['what'].apply(lambda x: str(x).lstrip(','))
251
+ data['what']=data['what'].apply(lambda x: str(x).replace('None,','').replace('None',''))
252
+ data['what']=data['what'].apply(remove_trail_comma)
253
+
254
+ data['why']=data['why'].apply(lambda x: str(x).replace(" 's","'s"))
255
+ data['why']=data['why'].apply(lambda x: str(x).replace("s ’","s’"))
256
+ data['why']=data['why'].apply(lambda x: str(x).replace(" - ","-"))
257
+ data['why']=data['why'].apply(lambda x: str(x).replace('\'','').replace('\'',''))
258
+ # data['why']=data['why'].apply(lambda x: str(x).replace('"','').replace('"',''))
259
+ data['why']=data['why'].apply(lambda x: str(x).replace('[','').replace(']',''))
260
+ data['why']=data['why'].apply(lambda x: str(x).rstrip(','))
261
+ data['why']=data['why'].apply(lambda x: str(x).lstrip(','))
262
+ data['why']=data['why'].apply(lambda x: str(x).replace('None,','').replace('None',''))
263
+ data['why']=data['why'].apply(remove_trail_comma)
264
+
265
+ data['when']=data['when'].apply(lambda x: str(x).replace(" 's","'s"))
266
+ data['when']=data['when'].apply(lambda x: str(x).replace("s ’","s’"))
267
+ data['when']=data['when'].apply(lambda x: str(x).replace(" - ","-"))
268
+ data['when']=data['when'].apply(lambda x: str(x).replace('\'','').replace('\'',''))
269
+ # data['when']=data['when'].apply(lambda x: str(x).replace('"','').replace('"',''))
270
+ data['when']=data['when'].apply(lambda x: str(x).replace('[','').replace(']',''))
271
+ data['when']=data['when'].apply(lambda x: str(x).rstrip(','))
272
+ data['when']=data['when'].apply(lambda x: str(x).lstrip(','))
273
+ data['when']=data['when'].apply(lambda x: str(x).replace('None,','').replace('None',''))
274
+ data['when']=data['when'].apply(remove_trail_comma)
275
+
276
+ data['where']=data['where'].apply(lambda x: str(x).replace(" 's","'s"))
277
+ data['where']=data['where'].apply(lambda x: str(x).replace("s ’","s’"))
278
+ data['where']=data['where'].apply(lambda x: str(x).replace(" - ","-"))
279
+ data['where']=data['where'].apply(lambda x: str(x).replace('\'','').replace('\'',''))
280
+ # data['where']=data['where'].apply(lambda x: str(x).replace('"','').replace('"',''))
281
+ data['where']=data['where'].apply(lambda x: str(x).replace('[','').replace(']',''))
282
+ data['where']=data['where'].apply(lambda x: str(x).rstrip(','))
283
+ data['where']=data['where'].apply(lambda x: str(x).lstrip(','))
284
+ data['where']=data['where'].apply(lambda x: str(x).replace('None,','').replace('None',''))
285
+ data['where']=data['where'].apply(remove_trail_comma)
286
+ return data
287
+ #-------------------------------------------------------------------------
288
+ def split_ws(input_list):
289
+ import re
290
+ output_list = []
291
+ for item in input_list:
292
+ split_item = re.findall(r'[^",]+|"[^"]*"', item)
293
+ output_list += split_item
294
+ result = [x.strip() for x in output_list]
295
+ return result
296
+
297
+ #--------------------------------------------------------------------------
298
+ def gen_qq(df):
299
+ w_list=["who","when","where","what","why"]
300
+ ans=[]
301
+ cl=[]
302
+ ind=[]
303
+ ques=[]
304
+ evid=[]
305
+ for index,value in enumerate(w_list):
306
+ for i,row in df.iterrows():
307
+ srl=df[value][i]
308
+ claim=df['claim'][i]
309
+ evidence_text=df['evidence'][i]
310
+ answer= split_ws(df[value])
311
+ try:
312
+ if len(srl.split())>0 and len(srl.split(","))>0:
313
+ for j in range(0,len(answer)):
314
+ FACT_TO_GENERATE_QUESTION_FROM = f"""{answer[j]} [SEP] {claim}"""
315
+ question_ids = query({"inputs":FACT_TO_GENERATE_QUESTION_FROM,
316
+ "num_beams":5,
317
+ "early_stopping":True})
318
+ #print("claim : {}".format(claim))
319
+ #print("answer : {}".format(answer[j]))
320
+ #print("question : {}".format(question_ids[0]['generated_text']))
321
+ ind.append(i)
322
+ cl.append(claim)
323
+ ans.append(answer[j])
324
+ ques.append(question_ids[0]['generated_text'].capitalize())
325
+ evid.append(evidence_text)
326
+ #print("-----------------------------------------")
327
+ except:
328
+ pass
329
+ return cl,ques,ans,evid
330
+ #------------------------------------------------------------
331
+ def qa_evidence(final_data):
332
+ ans=[]
333
+ cl=[]
334
+ #ind=[]
335
+ ques=[]
336
+ evi=[]
337
+ srl_ans=[]
338
+
339
+
340
+ for i,row in final_data.iterrows():
341
+ question=final_data['gen_question'][i]
342
+ evidence=final_data['evidence'][i]
343
+ claim=final_data['actual_claim'][i]
344
+ srl_answer=final_data['actual_answer'][i]
345
+ #index=df["index"][i]
346
+
347
+ input_evidence = f"question: {question} context: {evidence}"
348
+
349
+ answer = query_evidence({
350
+ "inputs":input_evidence,
351
+ "truncation":True})
352
+
353
+ #ind.append(index)
354
+ cl.append(claim)
355
+ ans.append(answer[0]["generated_text"])
356
+ ques.append(question)
357
+ evi.append(evidence)
358
+ srl_ans.append(srl_answer)
359
+
360
+ #print(f"""index: {index}""")
361
+ # print(f"""evidence: {evidence}""")
362
+ # print(f"""claim: {claim}""")
363
+ # print(f"""Question: {question}""")
364
+ # print(f"""Answer: {answer}""")
365
+ # print(f"""SRL Answer: {srl_answer}""")
366
+ # print("------------------------------------")
367
+ # return list(zip(cl,ques,srl_ans)),list(zip(evi,ques,ans))
368
+ # return cl,ques
369
+ return list(zip(ques,srl_ans)),list(zip(ques,ans))
370
+
371
+ #------------------------------------------------------------
372
+
373
+ if claim_text:
374
+ if evidence_text:
375
+ df=claim(claim_text)
376
+ df["evidence"]=evidence_text
377
+ actual_claim,gen_question,actual_answer,evidence=gen_qq(df)
378
+ final_data=pd.DataFrame([actual_claim,gen_question,actual_answer,evidence]).T
379
+ final_data.columns=["actual_claim","gen_question","actual_answer","evidence"]
380
+ a,b=qa_evidence(final_data)
381
+ # qa_evidence(final_data)
382
+ # st.json(qa_evidence(final_data))
383
+ st.json({'QA pair from claim':[{"Question": qu, "Answer": an} for qu, an in a],
384
+ 'QA pair from evidence':[{"Question": qu, "Answer": an} for qu, an in b]})