ashishraics commited on
Commit
59a7eba
1 Parent(s): 65ac75e

MLM bert added for ZS

Browse files
.gitignore CHANGED
@@ -2,5 +2,6 @@ venv/
2
  #exclude model files as they are large
3
  sentiment_model_dir/pytorch_model.bin
4
  zs_model_dir/pytorch_model.bin
 
5
  #sent_clf_onnx_dir/
6
  #zs_onnx_dir/
 
2
  #exclude model files as they are large
3
  sentiment_model_dir/pytorch_model.bin
4
  zs_model_dir/pytorch_model.bin
5
+ zs_mlm_dir/pytorch_model.bin
6
  #sent_clf_onnx_dir/
7
  #zs_onnx_dir/
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  import streamlit as st
4
  from streamlit_text_rating.st_text_rater import st_text_rater
5
  from transformers import AutoTokenizer,AutoModelForSequenceClassification
 
6
  import onnxruntime as ort
7
  import os
8
  import time
@@ -11,8 +12,15 @@ import plotly.graph_objects as go
11
  global _plotly_config
12
  _plotly_config={'displayModeBar': False}
13
 
14
- from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
15
- from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
 
 
 
 
 
 
 
16
 
17
  import multiprocessing
18
  total_threads=multiprocessing.cpu_count()#for ort inference
@@ -36,6 +44,10 @@ zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
36
  zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
37
  zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
38
 
 
 
 
 
39
 
40
  st.set_page_config( # Alternate names: setup_page, page, layout
41
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
@@ -43,7 +55,6 @@ st.set_page_config( # Alternate names: setup_page, page, layout
43
  page_title='None', # String or None. Strings get appended with "• Streamlit".
44
  )
45
 
46
-
47
  padding_top = 0
48
  st.markdown(f"""
49
  <style>
@@ -98,16 +109,24 @@ session_options_ort.inter_op_num_threads=1
98
  # session_options_ort.execution_mode = session_options_ort.ExecutionMode.ORT_SEQUENTIAL
99
 
100
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
101
- def create_model_dir(chkpt, model_dir):
102
  if not os.path.exists(model_dir):
103
  try:
104
  os.mkdir(path=model_dir)
105
  except:
106
  pass
107
- _model = AutoModelForSequenceClassification.from_pretrained(chkpt)
108
- _tokenizer = AutoTokenizer.from_pretrained(chkpt)
109
- _model.save_pretrained(model_dir)
110
- _tokenizer.save_pretrained(model_dir)
 
 
 
 
 
 
 
 
111
  else:
112
  pass
113
 
@@ -125,7 +144,7 @@ with st.sidebar:
125
  ############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
126
 
127
  # #create model/token dir for sentiment classification for faster inference
128
- create_model_dir(chkpt=sent_chkpt, model_dir=sent_mdl_dir)
129
 
130
 
131
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
@@ -135,7 +154,7 @@ def sentiment_task_selected(task,
135
  sent_onnx_mdl_dir=sent_onnx_mdl_dir,
136
  sent_onnx_mdl_name=sent_onnx_mdl_name,
137
  sent_onnx_quant_mdl_name=sent_onnx_quant_mdl_name):
138
- #model & tokenizer initialization for normal sentiment classification
139
  # model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_chkpt)
140
  # tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_chkpt)
141
  tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_mdl_dir)
@@ -152,18 +171,17 @@ def sentiment_task_selected(task,
152
  ############## Pre-Download & instantiate objects for sentiment analysis ********************* END **********************************
153
 
154
 
155
- ############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
156
 
157
  # create model/token dir for zeroshot clf -- already created so not required
158
- create_model_dir(chkpt=zs_chkpt, model_dir=zs_mdl_dir)
159
 
160
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
161
- def zs_task_selected(task,
162
- zs_chkpt=zs_chkpt ,
163
- zs_mdl_dir=zs_mdl_dir,
164
- zs_onnx_mdl_dir=zs_onnx_mdl_dir,
165
- zs_onnx_mdl_name=zs_onnx_mdl_name,
166
- zs_onnx_quant_mdl_name=zs_onnx_quant_mdl_name):
167
 
168
  ##model & tokenizer initialization for normal ZS classification
169
  # model_zs=AutoModelForSequenceClassification.from_pretrained(zs_chkpt)
@@ -171,16 +189,46 @@ def zs_task_selected(task,
171
  # tokenizer_zs=AutoTokenizer.from_pretrained(zs_chkpt)
172
  tokenizer_zs = AutoTokenizer.from_pretrained(zs_mdl_dir)
173
 
174
- # # create onnx model for zeroshot but once created locally comment it out.
175
- # create_onnx_model_zs()
176
 
177
  #create inference session from onnx model
178
  zs_session = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",sess_options=session_options_ort)
179
- # zs_session_quant = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}")
180
 
181
  return tokenizer_zs,zs_session
182
 
183
- ############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  if select_task=='README':
186
  st.header("NLP Summary")
@@ -196,7 +244,7 @@ if select_task == 'Detect Sentiment':
196
  t2 = time.time()
197
  st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
198
 
199
- st.header("You are now performing Sentiment Analysis")
200
  input_texts = st.text_input(label="Input texts separated by comma")
201
  c1,c2,_,_=st.columns(4)
202
 
@@ -223,35 +271,73 @@ if select_task == 'Detect Sentiment':
223
 
224
  if select_task=='Zero Shot Classification':
225
  t1=time.time()
226
- tokenizer_zs,session_zs = zs_task_selected(task=select_task)
227
- # tokenizer_zs= AutoTokenizer.from_pretrained(zs_mdl_dir)
228
- # session_zs = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}")
 
 
229
  t2 = time.time()
230
- st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
231
 
232
- st.header("You are now performing Zero Shot Classification")
 
 
 
 
 
 
 
 
 
233
  input_texts = st.text_input(label="Input text to classify into topics")
234
  input_lables = st.text_input(label="Enter labels separated by commas")
 
235
 
236
- c1,_,_,_=st.columns(4)
237
 
238
  with c1:
239
- response1=st.button("Compute (ONNX runtime)")
 
 
 
240
 
241
  if response1:
242
  start = time.time()
243
- df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=session_zs,
244
- _tokenizer=tokenizer_zs)
 
 
 
 
245
  end = time.time()
246
- st.write("")
247
  st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
248
  fig = px.bar(x='Probability',
249
  y='labels',
250
  text='Probability',
251
  data_frame=df_output,
252
- title='Zero Shot Normalized Probabilities')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  st.plotly_chart(fig, config=_plotly_config)
255
  else:
256
  pass
257
 
 
 
3
  import streamlit as st
4
  from streamlit_text_rating.st_text_rater import st_text_rater
5
  from transformers import AutoTokenizer,AutoModelForSequenceClassification
6
+ from transformers import AutoModelForMaskedLM
7
  import onnxruntime as ort
8
  import os
9
  import time
 
12
  global _plotly_config
13
  _plotly_config={'displayModeBar': False}
14
 
15
+ from sentiment_clf_helper import (classify_sentiment,
16
+ create_onnx_model_sentiment,
17
+ classify_sentiment_onnx)
18
+
19
+ from zeroshot_clf_helper import (zero_shot_classification,
20
+ create_onnx_model_zs_nli,
21
+ create_onnx_model_zs_mlm,
22
+ zero_shot_classification_nli_onnx,
23
+ zero_shot_classification_fillmask_onnx)
24
 
25
  import multiprocessing
26
  total_threads=multiprocessing.cpu_count()#for ort inference
 
44
  zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
45
  zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
46
 
47
+ zs_mlm_chkpt=config['ZEROSHOT_MLM']['zs_mlm_chkpt']
48
+ zs_mlm_mdl_dir=config['ZEROSHOT_MLM']['zs_mlm_mdl_dir']
49
+ zs_mlm_onnx_mdl_dir=config['ZEROSHOT_MLM']['zs_mlm_onnx_mdl_dir']
50
+ zs_mlm_onnx_mdl_name=config['ZEROSHOT_MLM']['zs_mlm_onnx_mdl_name']
51
 
52
  st.set_page_config( # Alternate names: setup_page, page, layout
53
  layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
 
55
  page_title='None', # String or None. Strings get appended with "• Streamlit".
56
  )
57
 
 
58
  padding_top = 0
59
  st.markdown(f"""
60
  <style>
 
109
  # session_options_ort.execution_mode = session_options_ort.ExecutionMode.ORT_SEQUENTIAL
110
 
111
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
112
+ def create_model_dir(chkpt, model_dir,task_type):
113
  if not os.path.exists(model_dir):
114
  try:
115
  os.mkdir(path=model_dir)
116
  except:
117
  pass
118
+ if task_type=='classification':
119
+ _model = AutoModelForSequenceClassification.from_pretrained(chkpt)
120
+ _tokenizer = AutoTokenizer.from_pretrained(chkpt)
121
+ _model.save_pretrained(model_dir)
122
+ _tokenizer.save_pretrained(model_dir)
123
+ elif task_type=='mlm':
124
+ _model=AutoModelForMaskedLM.from_pretrained(chkpt)
125
+ _tokenizer=AutoTokenizer.from_pretrained(chkpt)
126
+ _model.save_pretrained(model_dir)
127
+ _tokenizer.save_pretrained(model_dir)
128
+ else:
129
+ pass
130
  else:
131
  pass
132
 
 
144
  ############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
145
 
146
  # #create model/token dir for sentiment classification for faster inference
147
+ create_model_dir(chkpt=sent_chkpt, model_dir=sent_mdl_dir,task_type='classification')
148
 
149
 
150
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
 
154
  sent_onnx_mdl_dir=sent_onnx_mdl_dir,
155
  sent_onnx_mdl_name=sent_onnx_mdl_name,
156
  sent_onnx_quant_mdl_name=sent_onnx_quant_mdl_name):
157
+ ##model & tokenizer initialization for normal sentiment classification
158
  # model_sentiment=AutoModelForSequenceClassification.from_pretrained(sent_chkpt)
159
  # tokenizer_sentiment=AutoTokenizer.from_pretrained(sent_chkpt)
160
  tokenizer_sentiment = AutoTokenizer.from_pretrained(sent_mdl_dir)
 
171
  ############## Pre-Download & instantiate objects for sentiment analysis ********************* END **********************************
172
 
173
 
174
+ ############### Pre-Download & instantiate objects for Zero shot clf NLI *********************** START **********************
175
 
176
  # create model/token dir for zeroshot clf -- already created so not required
177
+ create_model_dir(chkpt=zs_chkpt, model_dir=zs_mdl_dir,task_type='classification')
178
 
179
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
180
+ def zs_nli_task_selected(task,
181
+ zs_chkpt ,
182
+ zs_mdl_dir,
183
+ zs_onnx_mdl_dir,
184
+ zs_onnx_mdl_name):
 
185
 
186
  ##model & tokenizer initialization for normal ZS classification
187
  # model_zs=AutoModelForSequenceClassification.from_pretrained(zs_chkpt)
 
189
  # tokenizer_zs=AutoTokenizer.from_pretrained(zs_chkpt)
190
  tokenizer_zs = AutoTokenizer.from_pretrained(zs_mdl_dir)
191
 
192
+ ## create onnx model for zeroshot but once created locally comment it out.
193
+ #create_onnx_model_zs_nli()
194
 
195
  #create inference session from onnx model
196
  zs_session = ort.InferenceSession(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",sess_options=session_options_ort)
 
197
 
198
  return tokenizer_zs,zs_session
199
 
200
+ ############## Pre-Download & instantiate objects for Zero shot NLI analysis ********************* END **********************************
201
+
202
+
203
+ ############### Pre-Download & instantiate objects for Zero shot clf NLI *********************** START **********************
204
+ ## create model/token dir for zeroshot clf -- already created so not required
205
+ # create_model_dir(chkpt=zs_mlm_chkpt, model_dir=zs_mlm_mdl_dir, task_type='mlm')
206
+
207
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
208
+ def zs_mlm_task_selected(task,
209
+ zs_mlm_chkpt=zs_mlm_chkpt,
210
+ zs_mlm_mdl_dir=zs_mlm_mdl_dir,
211
+ zs_mlm_onnx_mdl_dir=zs_mlm_onnx_mdl_dir,
212
+ zs_mlm_onnx_mdl_name=zs_mlm_onnx_mdl_name):
213
+ ##model & tokenizer initialization for normal ZS classification
214
+ model_zs_mlm=AutoModelForMaskedLM.from_pretrained(zs_mlm_mdl_dir)
215
+ ##we just need tokenizer for inference and not model since onnx model is already saved
216
+ # tokenizer_zs_mlm=AutoTokenizer.from_pretrained(zs_mlm_chkpt)
217
+ tokenizer_zs_mlm = AutoTokenizer.from_pretrained(zs_mlm_mdl_dir)
218
+
219
+ # create onnx model for zeroshot but once created locally comment it out.
220
+ create_onnx_model_zs_mlm(_model=model_zs_mlm,
221
+ _tokenizer=tokenizer_zs_mlm,
222
+ zs_mlm_onnx_mdl_dir=zs_mlm_onnx_mdl_dir)
223
+
224
+ # create inference session from onnx model
225
+ zs_session_mlm = ort.InferenceSession(f"{zs_mlm_onnx_mdl_dir}/{zs_mlm_onnx_mdl_name}", sess_options=session_options_ort)
226
+
227
+ return tokenizer_zs_mlm, zs_session_mlm
228
+
229
+
230
+ ############## Pre-Download & instantiate objects for Zero shot MLM analysis ********************* END **********************************
231
+
232
 
233
  if select_task=='README':
234
  st.header("NLP Summary")
 
244
  t2 = time.time()
245
  st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
246
 
247
+ st.subheader("You are now performing Sentiment Analysis")
248
  input_texts = st.text_input(label="Input texts separated by comma")
249
  c1,c2,_,_=st.columns(4)
250
 
 
271
 
272
  if select_task=='Zero Shot Classification':
273
  t1=time.time()
274
+ tokenizer_zs,session_zs = zs_nli_task_selected(task=select_task ,
275
+ zs_chkpt=zs_chkpt,
276
+ zs_mdl_dir=zs_mdl_dir,
277
+ zs_onnx_mdl_dir=zs_onnx_mdl_dir,
278
+ zs_onnx_mdl_name=zs_onnx_mdl_name)
279
  t2 = time.time()
280
+ st.write(f"Total time to load NLI Model is {(t2-t1)*1000:.1f} ms")
281
 
282
+ t1=time.time()
283
+ tokenizer_zs_mlm,session_zs_mlm = zs_mlm_task_selected(task=select_task,
284
+ zs_mlm_chkpt=zs_mlm_chkpt,
285
+ zs_mlm_mdl_dir=zs_mlm_mdl_dir,
286
+ zs_mlm_onnx_mdl_dir=zs_mlm_onnx_mdl_dir,
287
+ zs_mlm_onnx_mdl_name=zs_mlm_onnx_mdl_name)
288
+ t2 = time.time()
289
+ st.write(f"Total time to load MLM Model is {(t2-t1)*1000:.1f} ms")
290
+
291
+ st.subheader("Zero Shot Classification using NLI")
292
  input_texts = st.text_input(label="Input text to classify into topics")
293
  input_lables = st.text_input(label="Enter labels separated by commas")
294
+ input_hypothesis = st.text_input(label="Enter your hypothesis",value="This is an example of")
295
 
296
+ c1,c2,_,=st.columns(3)
297
 
298
  with c1:
299
+ response1=st.button("Compute using NLI approach (ONNX runtime)")
300
+
301
+ with c2:
302
+ response2=st.button("Compute using Fill-Mask approach(ONNX runtime)")
303
 
304
  if response1:
305
  start = time.time()
306
+ df_output = zero_shot_classification_nli_onnx(premise=input_texts,
307
+ labels=input_lables,
308
+ hypothesis=input_hypothesis,
309
+ _session=session_zs,
310
+ _tokenizer=tokenizer_zs,
311
+ )
312
  end = time.time()
 
313
  st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
314
  fig = px.bar(x='Probability',
315
  y='labels',
316
  text='Probability',
317
  data_frame=df_output,
318
+ title='Zero Shot NLI Normalized Probabilities')
319
+
320
+ st.plotly_chart(fig, config=_plotly_config)
321
+
322
+ elif response2:
323
+ start=time.time()
324
+ df_output=zero_shot_classification_fillmask_onnx(premise=input_texts,
325
+ labels=input_lables,
326
+ hypothesis=input_hypothesis,
327
+ _session=session_zs_mlm,
328
+ _tokenizer=tokenizer_zs_mlm,
329
+ )
330
+ end=time.time()
331
+ st.write(f"Time taken for computation {(end - start) * 1000:.1f} ms")
332
+
333
+ fig = px.bar(x='Probability',
334
+ y='Labels',
335
+ text='Probability',
336
+ data_frame=df_output,
337
+ title='Zero Shot MLM Normalized Probabilities')
338
 
339
  st.plotly_chart(fig, config=_plotly_config)
340
  else:
341
  pass
342
 
343
+
config.yaml CHANGED
@@ -12,3 +12,10 @@ ZEROSHOT_CLF:
12
  zs_onnx_mdl_name: 'model.onnx'
13
  zs_onnx_quant_mdl_name: 'model_quant.onnx'
14
 
 
 
 
 
 
 
 
 
12
  zs_onnx_mdl_name: 'model.onnx'
13
  zs_onnx_quant_mdl_name: 'model_quant.onnx'
14
 
15
+ ZEROSHOT_MLM:
16
+ zs_mlm_chkpt: 'bert-base-uncased'
17
+ zs_mlm_mdl_dir: 'zs_mlm_dir'
18
+ zs_mlm_onnx_mdl_dir: 'zs_mlm_onnx_dir'
19
+ zs_mlm_onnx_mdl_name: 'model.onnx'
20
+
21
+
zeroshot_clf_helper.py CHANGED
@@ -4,6 +4,10 @@ import os
4
  import subprocess
5
  import numpy as np
6
  import pandas as pd
 
 
 
 
7
 
8
  import yaml
9
  def read_yaml(file_path):
@@ -18,8 +22,24 @@ zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
18
  zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
19
  zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
20
 
 
 
 
 
 
21
 
22
  def zero_shot_classification(premise: str, labels: str, model, tokenizer):
 
 
 
 
 
 
 
 
 
 
 
23
  try:
24
  labels=labels.split(',')
25
  labels=[l.lower() for l in labels]
@@ -49,11 +69,19 @@ def zero_shot_classification(premise: str, labels: str, model, tokenizer):
49
  return df
50
 
51
  ##example
52
- # zero_shot_classification(premise='Tiny worms and breath analyzers could screen for \disease while it’s early and treatable',
53
  # labels='science, sports, museum')
54
 
55
 
56
- def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir):
 
 
 
 
 
 
 
 
57
 
58
  # create onnx model using
59
  if not os.path.exists(zs_onnx_mdl_dir):
@@ -61,6 +89,7 @@ def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir):
61
  subprocess.run(['python3', '-m', 'transformers.onnx',
62
  '--model=valhalla/distilbart-mnli-12-1',
63
  '--feature=sequence-classification',
 
64
  zs_onnx_mdl_dir])
65
  except Exception as e:
66
  print(e)
@@ -72,7 +101,19 @@ def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir):
72
  else:
73
  pass
74
 
75
- def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
  labels=labels.split(',')
78
  labels=[l.lower() for l in labels]
@@ -85,7 +126,7 @@ def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
85
 
86
  for l in labels:
87
 
88
- hypothesis= f'this is an example of {l}'
89
 
90
  inputs = _tokenizer(premise,hypothesis,
91
  return_tensors='pt',
@@ -109,4 +150,84 @@ def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
109
  return df
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
 
 
4
  import subprocess
5
  import numpy as np
6
  import pandas as pd
7
+ import transformers
8
+ import transformers.convert_graph_to_onnx as onnx_convert
9
+ from pathlib import Path
10
+ import streamlit as st
11
 
12
  import yaml
13
  def read_yaml(file_path):
 
22
  zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
23
  zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
24
 
25
+ zs_mlm_chkpt=config['ZEROSHOT_MLM']['zs_mlm_chkpt']
26
+ zs_mlm_mdl_dir=config['ZEROSHOT_MLM']['zs_mlm_mdl_dir']
27
+ zs_mlm_onnx_mdl_dir=config['ZEROSHOT_MLM']['zs_mlm_onnx_mdl_dir']
28
+ zs_mlm_onnx_mdl_name=config['ZEROSHOT_MLM']['zs_mlm_onnx_mdl_name']
29
+
30
 
31
  def zero_shot_classification(premise: str, labels: str, model, tokenizer):
32
+ """
33
+
34
+ Args:
35
+ premise:
36
+ labels:
37
+ model:
38
+ tokenizer:
39
+
40
+ Returns:
41
+
42
+ """
43
  try:
44
  labels=labels.split(',')
45
  labels=[l.lower() for l in labels]
 
69
  return df
70
 
71
  ##example
72
+ # zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
73
  # labels='science, sports, museum')
74
 
75
 
76
+ def create_onnx_model_zs_nli(zs_onnx_mdl_dir=zs_onnx_mdl_dir):
77
+ """
78
+
79
+ Args:
80
+ zs_onnx_mdl_dir:
81
+
82
+ Returns:
83
+
84
+ """
85
 
86
  # create onnx model using
87
  if not os.path.exists(zs_onnx_mdl_dir):
 
89
  subprocess.run(['python3', '-m', 'transformers.onnx',
90
  '--model=valhalla/distilbart-mnli-12-1',
91
  '--feature=sequence-classification',
92
+ '--atol=1e-3',
93
  zs_onnx_mdl_dir])
94
  except Exception as e:
95
  print(e)
 
101
  else:
102
  pass
103
 
104
+ def zero_shot_classification_nli_onnx(premise,labels,_session,_tokenizer,hypothesis="This is an example of"):
105
+ """
106
+
107
+ Args:
108
+ premise:
109
+ labels:
110
+ _session:
111
+ _tokenizer:
112
+ hypothesis:
113
+
114
+ Returns:
115
+
116
+ """
117
  try:
118
  labels=labels.split(',')
119
  labels=[l.lower() for l in labels]
 
126
 
127
  for l in labels:
128
 
129
+ hypothesis= f"{hypothesis} {l}"
130
 
131
  inputs = _tokenizer(premise,hypothesis,
132
  return_tensors='pt',
 
150
  return df
151
 
152
 
153
+ def create_onnx_model_zs_mlm(_model, _tokenizer,zs_mlm_onnx_mdl_dir=zs_mlm_onnx_mdl_dir):
154
+ """
155
+
156
+ Args:
157
+ _model:
158
+ _tokenizer:
159
+ zs_mlm_onnx_mdl_dir:
160
+
161
+ Returns:
162
+
163
+ """
164
+ if not os.path.exists(zs_mlm_onnx_mdl_dir):
165
+ try:
166
+ subprocess.run(['python3', '-m', 'transformers.onnx',
167
+ f'--model={zs_mlm_chkpt}',
168
+ '--feature=masked-lm',
169
+ zs_mlm_onnx_mdl_dir])
170
+ except:
171
+ pass
172
+
173
+ else:
174
+ pass
175
+
176
+ def zero_shot_classification_fillmask_onnx(premise,hypothesis,labels,_session,_tokenizer):
177
+ """
178
+
179
+ Args:
180
+ premise:
181
+ hypothesis:
182
+ labels:
183
+ _session:
184
+ _tokenizer:
185
+
186
+ Returns:
187
+
188
+ """
189
+ try:
190
+ labels=labels.split(',')
191
+ labels=[l.lower().rstrip().lstrip() for l in labels]
192
+ except:
193
+ raise Exception("please pass atleast 2 labels to classify")
194
+
195
+ premise=premise.lower()
196
+ hypothesis=hypothesis.lower()
197
+
198
+ final_input= f"{premise}.{hypothesis} [MASK]" #this can change depending on chkpt, this is for bert-base-uncased chkpt
199
+
200
+ _inputs=_tokenizer(final_input,padding=True, truncation=True,
201
+ return_tensors="pt")
202
+
203
+ input_feed={
204
+ 'input_ids': np.array(_inputs['input_ids']),
205
+ 'token_type_ids': np.array(_inputs['token_type_ids']),
206
+ 'attention_mask': np.array(_inputs['attention_mask'])
207
+ }
208
+
209
+ output=_session.run(output_names=['logits'],input_feed=dict(input_feed))[0]
210
+
211
+ mask_token_index = np.argwhere(_inputs["input_ids"] == _tokenizer.mask_token_id)[1,0]
212
+
213
+ mask_token_logits=output[0,mask_token_index,:]
214
+
215
+ #seacrh for logits of input labels
216
+ #encode the labels and get the label id -
217
+ labels_logits=[]
218
+ for l in labels:
219
+ encoded_label=_tokenizer.encode(l)[1]
220
+ labels_logits.append(mask_token_logits[encoded_label])
221
+
222
+ #do a softmax on the logits
223
+ labels_logits=np.array(labels_logits)
224
+ labels_logits=torch.from_numpy(labels_logits)
225
+ labels_logits=labels_logits.softmax(dim=0)
226
+
227
+ output= {'Labels':labels,
228
+ 'Probability':labels_logits}
229
+
230
+ df_output = pd.DataFrame(output)
231
+ df_output['Probability'] = df_output['Probability'].apply(lambda x: np.round(100*x, 1))
232
 
233
+ return df_output
zs_mlm_dir/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "bert-base-uncased",
3
+ "architectures": [
4
+ "BertForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-12,
15
+ "max_position_embeddings": 512,
16
+ "model_type": "bert",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 0,
20
+ "position_embedding_type": "absolute",
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.18.0",
23
+ "type_vocab_size": 2,
24
+ "use_cache": true,
25
+ "vocab_size": 30522
26
+ }
zs_mlm_dir/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
zs_mlm_dir/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
zs_mlm_dir/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "bert-base-uncased", "tokenizer_class": "BertTokenizer"}
zs_mlm_dir/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
zs_mlm_onnx_dir/model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e55575cce3f2b3b68e82e4dbdaabff3a8a5eaaeac4703e4000b1cb717174543a
3
+ size 531893756