ppsingh committed on
Commit
05828e0
1 Parent(s): bdab6a0

adding appstore and utils main scripts

app.py CHANGED
@@ -2,6 +2,9 @@ import streamlit as st
 from utils.uploadAndExample import add_upload
 from utils.config import model_dict
 from utils.vulnerability_classifier import label_dict
+import appStore.doc_processing as processing
+import appStore.vulnerability_analysis as vulnerability_analysis
+import appStore.target as target_analysis
 
 with st.sidebar:
     # upload and example doc
@@ -23,4 +26,30 @@ with st.sidebar:
 
 with st.container():
     st.markdown("<h2 style='text-align: center;'> Vulnerability Analysis 3.1 </h2>", unsafe_allow_html=True)
-    st.write(' ')
+    st.write(' ')
+
+    with st.expander("ℹ️ - About this app", expanded=False):
+        st.write(
+            """
+            The Vulnerability Analysis App is an open-source \
+            digital tool which aims to assist policy analysts and \
+            other users in extracting and filtering references \
+            to different groups in vulnerable situations from public documents. \
+            We use Natural Language Processing (NLP), specifically deep \
+            learning-based text representations, to search context-sensitively \
+            for mentions of the special needs of groups in vulnerable situations \
+            and to cluster them thematically.
+            For more detail on the methodology, [click here](https://vulnerability-analysis.streamlit.app/).
+            """)
+
+        st.write("""
+            What happens in the background?
+
+            - Step 1: Once the document is provided to the app, it undergoes *pre-processing*. \
+              In this step the document is broken into smaller paragraphs \
+              (based on word/sentence count).
+            - Step 2: The paragraphs are then fed to the **Vulnerability Classifier**, which detects whether \
+              the paragraph contains one or multiple references to vulnerable groups.
+            """)
+
+        st.write("")
appStore/__init__.py ADDED
@@ -0,0 +1 @@
+# more app related files
appStore/doc_processing.py ADDED
@@ -0,0 +1,80 @@
+# set path
+import glob, os, sys
+sys.path.append('../utils')
+from typing import List, Tuple
+from typing_extensions import Literal
+from haystack.schema import Document
+from utils.config import get_classifier_params
+from utils.preprocessing import processingpipeline, paraLengthCheck
+import streamlit as st
+import logging
+import pandas as pd
+import nltk
+nltk.download('punkt_tab')
+
+params = get_classifier_params("preprocessing")
+
+@st.cache_data
+def runPreprocessingPipeline(file_name: str, file_path: str,
+                             split_by: Literal["sentence", "word"] = 'sentence',
+                             split_length: int = 2, split_respect_sentence_boundary: bool = False,
+                             split_overlap: int = 0, remove_punc: bool = False) -> List[Document]:
+    """
+    Creates and runs the preprocessing pipeline; the params for the pipeline
+    are fetched from paramconfig.
+
+    Params
+    ------------
+    file_name: filename, in a Streamlit application use st.session_state['filename']
+    file_path: filepath, in a Streamlit application use st.session_state['filepath']
+    split_by: document splitting strategy, either 'word' or 'sentence'
+    split_length: when synthetically creating the paragraphs from the document,
+        defines the length of a paragraph.
+    split_respect_sentence_boundary: used with the 'word' strategy for
+        splitting the text.
+    split_overlap: number of words or sentences that overlap when creating
+        the paragraphs. A sentence or a few words often only make sense
+        when read together with their neighbours, hence the overlap.
+    remove_punc: whether to remove all punctuation, including ',' and '.'
+
+    Return
+    --------------
+    List[Document]: when the preprocessing pipeline is run, the output dictionary
+        has four objects. For the Haystack implementation of the classification we
+        need the list of Haystack Documents, which can be fetched with
+        key = 'documents' on the output.
+    """
+
+    processing_pipeline = processingpipeline()
+
+    output_pre = processing_pipeline.run(file_paths=file_path,
+                                         params={"FileConverter": {"file_path": file_path,
+                                                                   "file_name": file_name},
+                                                 "UdfPreProcessor": {"remove_punc": remove_punc,
+                                                                     "split_by": split_by,
+                                                                     "split_length": split_length,
+                                                                     "split_overlap": split_overlap,
+                                                                     "split_respect_sentence_boundary": split_respect_sentence_boundary}})
+
+    return output_pre
+
+
+def app():
+    with st.container():
+        if 'filepath' in st.session_state:
+            file_name = st.session_state['filename']
+            file_path = st.session_state['filepath']
+
+            all_documents = runPreprocessingPipeline(file_name=file_name,
+                                                     file_path=file_path, split_by=params['split_by'],
+                                                     split_length=params['split_length'],
+                                                     split_respect_sentence_boundary=params['split_respect_sentence_boundary'],
+                                                     split_overlap=params['split_overlap'], remove_punc=params['remove_punc'])
+            paralist = paraLengthCheck(all_documents['documents'], 100)
+            df = pd.DataFrame(paralist, columns=['text', 'page'])
+            # saving the dataframe to session state
+            st.session_state['key0'] = df
+
+        else:
+            st.info("🤔 No document found, please try to upload it at the sidebar!")
+            logging.warning("Terminated as no document provided")
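A hedged usage sketch of `runPreprocessingPipeline` outside Streamlit, mirroring how `app()` calls it above; the file name, path and parameter values are illustrative only.

```python
# Illustrative call of runPreprocessingPipeline (file name/path and params are made up).
output = runPreprocessingPipeline(
    file_name="example_policy.pdf",
    file_path="/tmp/example_policy.pdf",
    split_by="sentence",      # sentence-based chunking
    split_length=2,           # two sentences per synthetic paragraph
    split_overlap=0,
    remove_punc=False,
)
paragraphs = paraLengthCheck(output['documents'], 100)   # cap paragraph length at 100 words
df = pd.DataFrame(paragraphs, columns=['text', 'page'])  # same shape as st.session_state['key0']
```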
appStore/target.py ADDED
@@ -0,0 +1,111 @@
+# set path
+import glob, os, sys
+sys.path.append('../utils')
+
+# import needed libraries
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import streamlit as st
+from utils.target_classifier import load_targetClassifier, target_classification
+import logging
+logger = logging.getLogger(__name__)
+from utils.config import get_classifier_params
+from utils.preprocessing import paraLengthCheck
+from io import BytesIO
+import xlsxwriter
+import plotly.express as px
+from utils.target_classifier import label_dict
+from appStore.rag import run_query
+
+# Declare all the necessary variables
+classifier_identifier = 'target'
+params = get_classifier_params(classifier_identifier)
+
+@st.cache_data
+def to_excel(df, sectorlist):
+    len_df = len(df)
+    output = BytesIO()
+    writer = pd.ExcelWriter(output, engine='xlsxwriter')
+    df.to_excel(writer, index=False, sheet_name='Sheet1')
+    workbook = writer.book
+    worksheet = writer.sheets['Sheet1']
+    worksheet.data_validation('S2:S{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': ['No', 'Yes', 'Discard']})
+    worksheet.data_validation('X2:X{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('T2:T{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('U2:U{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('V2:V{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('W2:W{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    writer.save()
+    processed_data = output.getvalue()
+    return processed_data
+
+def app():
+
+    ### Main app code ###
+    with st.container():
+
+        if 'key1' in st.session_state:
+
+            # Load the existing dataset
+            df = st.session_state.key1
+
+            # Filter out all paragraphs that do not have a reference to groups
+            df = df[df['Vulnerability Label'].apply(lambda x: len(x) > 0 and 'Other' not in x)]
+
+            # Load the classifier model
+            classifier = load_targetClassifier(classifier_name=params['model_name'])
+
+            st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
+
+            df = target_classification(haystack_doc=df,
+                                       threshold=params['threshold'])
+
+            # Rename column
+            df.rename(columns={'Target Label': 'Specific action/target/measure mentioned'}, inplace=True)
+
+            st.session_state.key2 = df
+
+
+def target_display(model_sel_name):
+
+    ### TABLE output ###
+
+    # Assign the dataframe a name
+    df = st.session_state['key2']
+    st.write(df)
+
+    ### RAG output by group ###
+
+    # Expand the DataFrame
+    df_expand = (
+        df.query("`Specific action/target/measure mentioned` == 'YES'")
+        .explode('Vulnerability Label')
+    )
+    # Group by 'Vulnerability Label' and concatenate 'text'
+    df_agg = df_expand.groupby('Vulnerability Label')['text'].agg('; '.join).reset_index()
+
+    # st.write(df_agg)
+
+    st.markdown("----")
+    st.markdown('**DOCUMENT FINDINGS SUMMARY BY VULNERABILITY LABEL:**')
+
+    # construct a RAG query for each label, send it to OpenAI and process the response
+    for i in range(0, len(df_agg)):
+        st.write(df_agg['Vulnerability Label'].iloc[i])
+        run_query(context=df_agg['text'].iloc[i], label=df_agg['Vulnerability Label'].iloc[i], model_sel_name=model_sel_name)
+
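The explode/groupby step above builds one concatenated context string per group before it is sent to `run_query`. A small self-contained illustration of the same pandas pattern on made-up data:

```python
import pandas as pd

# Toy data mirroring the columns used above (values are invented)
df = pd.DataFrame({
    'text': ['para A', 'para B'],
    'Vulnerability Label': [['Children', 'Women and other genders'], ['Children']],
})

df_agg = (
    df.explode('Vulnerability Label')          # one row per (paragraph, group) pair
      .groupby('Vulnerability Label')['text']
      .agg('; '.join)                          # all paragraphs mentioning a group, joined into one RAG context
      .reset_index()
)
print(df_agg)
# 'Children'                -> 'para A; para B'
# 'Women and other genders' -> 'para A'
```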
appStore/vulnerability_analysis.py ADDED
@@ -0,0 +1,169 @@
+# set path
+import glob, os, sys
+sys.path.append('../utils')
+
+# import needed libraries
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import streamlit as st
+from utils.vulnerability_classifier import load_vulnerabilityClassifier, vulnerability_classification
+import logging
+logger = logging.getLogger(__name__)
+from utils.config import get_classifier_params
+from utils.preprocessing import paraLengthCheck
+from io import BytesIO
+import xlsxwriter
+import plotly.express as px
+import plotly.graph_objects as go
+from utils.vulnerability_classifier import label_dict
+
+
+# Declare all the necessary variables
+classifier_identifier = 'vulnerability'
+params = get_classifier_params(classifier_identifier)
+
+@st.cache_data
+def to_excel(df, sectorlist):
+    len_df = len(df)
+    output = BytesIO()
+    writer = pd.ExcelWriter(output, engine='xlsxwriter')
+    df.to_excel(writer, index=False, sheet_name='Sheet1')
+    workbook = writer.book
+    worksheet = writer.sheets['Sheet1']
+    worksheet.data_validation('S2:S{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': ['No', 'Yes', 'Discard']})
+    worksheet.data_validation('X2:X{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('T2:T{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('U2:U{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('V2:V{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('W2:W{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    writer.save()
+    processed_data = output.getvalue()
+    return processed_data
+
+def app():
+
+    ### Main app code ###
+    with st.container():
+
+        # If a document has been processed
+        if 'key0' in st.session_state:
+
+            # Run vulnerability classifier
+            df = st.session_state.key0
+            classifier = load_vulnerabilityClassifier(classifier_name=params['model_name'])
+            st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
+
+            # Get the predictions
+            df = vulnerability_classification(haystack_doc=df,
+                                              threshold=params['threshold'])
+
+            # Store df in session state with key1
+            st.session_state.key1 = df
+
+
+def vulnerability_display():
+
+    # Get the vulnerability df
+    df = st.session_state['key1']
+
+    # Filter the dataframe to only show the paragraphs with references
+    df_filtered = df[df['Vulnerability Label'].apply(lambda x: len(x) > 0 and 'Other' not in x)]
+
+    # Rename column
+    df_filtered.rename(columns={'Vulnerability Label': 'Group(s)'}, inplace=True)
+
+    # Header
+    st.subheader("Explore references to vulnerable groups:")
+
+    # Text
+    num_paragraphs = len(df['Vulnerability Label'])
+    num_references = len(df_filtered['Group(s)'])
+
+    st.markdown(f"""<div style="text-align: justify;">The document contains a
+                total of <span style="color: red;">{num_paragraphs}</span> paragraphs.
+                We identified <span style="color: red;">{num_references}</span>
+                references to groups in vulnerable situations.</div>
+                <br>
+                <div style="text-align: justify;">We are searching for references related
+                to the following groups: (1) Agricultural communities, (2) Children, (3)
+                Ethnic, racial and other minorities, (4) Fishery communities, (5) Informal sector
+                workers, (6) Members of indigenous and local communities, (7) Migrants and
+                displaced persons, (8) Older persons, (9) Persons living in poverty, (10)
+                Persons living with disabilities, (11) Persons with pre-existing health conditions,
+                (12) Residents of drought-prone regions, (13) Rural populations, (14) Sexual
+                minorities (LGBTQI+), (15) Urban populations, (16) Women and other genders.</div>
+                <br>
+                <div style="text-align: justify;">The chart below shows the groups for which
+                references were found and the number of references identified.
+                For a more detailed view of the text, see the paragraphs and
+                their respective labels in the table underneath.</div>""", unsafe_allow_html=True)
+
+
+    ### Bar chart
+
+    # Create a df that stores all the labels
+    df_labels = pd.DataFrame(list(label_dict.items()), columns=['Label ID', 'Label'])
+
+    # Count how often each label appears in the "Group(s)" column
+    group_counts = {}
+
+    # Iterate through each row
+    for index, row in df_filtered.iterrows():
+
+        # Iterate through each group in the row's label list
+        for sublist in row['Group(s)']:
+
+            # Update the count in the dictionary
+            group_counts[sublist] = group_counts.get(sublist, 0) + 1
+
+    # Create a new dataframe from group_counts
+    df_label_count = pd.DataFrame(list(group_counts.items()), columns=['Label', 'Count'])
+
+    # Merge the label counts with the df_labels DataFrame
+    df_label_count = df_labels.merge(df_label_count, on='Label', how='left')
+
+    # Exclude the "Other" group and all groups that do not have a label
+    df_bar_chart = df_label_count[df_label_count['Label'] != 'Other']
+    df_bar_chart = df_bar_chart.dropna(subset=['Count'])
+
+    # Bar chart
+    fig = go.Figure()
+
+    fig.add_trace(go.Bar(
+        y=df_bar_chart.Label,
+        x=df_bar_chart.Count,
+        orientation='h',
+        marker=dict(color='purple'),
+    ))
+
+    # Customize layout
+    fig.update_layout(
+        title='Number of references identified',
+        xaxis_title='Number of references',
+        yaxis_title='Group',
+    )
+
+    # Show the plot
+    # fig.show()
+    st.plotly_chart(fig, use_container_width=True)
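The dictionary loop above tallies how often each group occurs across the multi-label predictions. On toy data, the same counts can also be obtained with `explode` + `value_counts`; this is only an equivalent illustration, not what the commit uses:

```python
import pandas as pd

# Toy multi-label predictions (made-up values)
df_filtered = pd.DataFrame({
    'Group(s)': [['Children'], ['Children', 'Rural populations'], ['Rural populations']],
})

# Equivalent to the explicit group_counts dictionary loop
counts = df_filtered['Group(s)'].explode().value_counts()
df_label_count = counts.rename_axis('Label').reset_index(name='Count')
print(df_label_count)   # Children: 2, Rural populations: 2
```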
utils/target_classifier.py ADDED
@@ -0,0 +1,125 @@
+from typing import List, Tuple
+from typing_extensions import Literal
+import logging
+import pandas as pd
+from pandas import DataFrame, Series
+from utils.config import getconfig
+from utils.preprocessing import processingpipeline
+import streamlit as st
+from setfit import SetFitModel
+from transformers import pipeline
+
+## Labels dictionary ###
+label_dict = {
+    0: 'NO',
+    1: 'YES',
+}
+
+def get_target_labels(preds):
+    """
+    Takes the numerical predictions as input and returns a list of the label names.
+    """
+
+    # Turn into list
+    preds_list = preds.numpy().tolist()
+
+    # Get label names
+    predictions_names = []
+
+    # loop through each prediction
+    for ele in preds_list:
+
+        # see if there is a value 1 and retrieve its index
+        try:
+            index_of_one = ele.index(1)
+        except ValueError:
+            index_of_one = "NA"
+
+        # Retrieve the name of the label (if no prediction was made -> "Other")
+        if index_of_one != "NA":
+            name = label_dict[index_of_one]
+        else:
+            name = "Other"
+
+        # Append name to list
+        predictions_names.append(name)
+
+    return predictions_names
+
+@st.cache_resource
+def load_targetClassifier(config_file: str = None, classifier_name: str = None):
+    """
+    Loads the document classifier using Haystack, where the name/path of the model
+    on the HF hub is used to fetch the model object. Either a config file or a
+    model name should be passed.
+    1. https://docs.haystack.deepset.ai/reference/document-classifier-api
+    2. https://docs.haystack.deepset.ai/docs/document_classifier
+
+    Params
+    --------
+    config_file: config file path from which to read the model name
+    classifier_name: if a model name is passed it takes priority; if not
+        found, the config file is used, otherwise an error is raised.
+
+    Return: document classifier model
+    """
+    if not classifier_name:
+        if not config_file:
+            logging.warning("Pass either model name or config file")
+            return
+        else:
+            config = getconfig(config_file)
+            classifier_name = config.get('target', 'MODEL')
+
+    logging.info("Loading classifier")
+
+    # Loading classifier
+    doc_classifier = SetFitModel.from_pretrained("leavoigt/vulnerability_target")
+
+    return doc_classifier
+
+
+@st.cache_data
+def target_classification(haystack_doc: pd.DataFrame,
+                          threshold: float = 0.5,
+                          classifier_model: pipeline = None
+                          ) -> Tuple[DataFrame, Series]:
+    """
+    Text classification on the list of texts provided. The classifier assigns the
+    most appropriate label to each text. These labels indicate whether the paragraph
+    references a specific action, target or measure.
+    ---------
+    haystack_doc: list of Haystack Documents. The output of the preprocessing pipeline
+        contains the list of paragraphs in different formats; here the list of
+        Haystack Documents is used.
+    threshold: threshold value for the model to keep the results from the classifier
+    classifier_model: the classifier model can be passed directly, in which case it takes
+        priority; otherwise the model is looked up in the Streamlit session.
+        In a Streamlit app, avoid passing the model directly.
+
+    Returns
+    ----------
+    df: Dataframe with the text and the assigned 'Target Label'
+    """
+
+    logging.info("Working on target/action identification")
+
+    haystack_doc['Target Label'] = 'NA'
+
+    if not classifier_model:
+        classifier_model = st.session_state['target_classifier']
+
+    # Get predictions
+    predictions = classifier_model(list(haystack_doc.text))
+
+    # Get labels for predictions
+    pred_labels = get_target_labels(predictions)
+
+    # Save labels
+    haystack_doc['Target Label'] = pred_labels
+
+    return haystack_doc
+
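To illustrate `get_target_labels`, a small example with an invented prediction tensor (assuming the SetFit model returns one-hot rows over the NO/YES labels and the function above is in scope):

```python
import torch

# Fake binary predictions for three paragraphs: [NO, YES] one-hot rows, plus one all-zero row
preds = torch.tensor([[0, 1],
                      [1, 0],
                      [0, 0]])

print(get_target_labels(preds))
# ['YES', 'NO', 'Other']  -- the all-zero row has no value 1, so it falls back to "Other"
```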
utils/vulnerability_classifier.py CHANGED
@@ -1,3 +1,15 @@
+from typing import List, Tuple
+from typing_extensions import Literal
+import logging
+import pandas as pd
+from pandas import DataFrame, Series
+from utils.config import getconfig
+from utils.preprocessing import processingpipeline
+import streamlit as st
+from transformers import pipeline
+from setfit import SetFitModel
+
+
 # labels dictionary
 label_dict= {0: 'Agricultural communities',
              1: 'Children',
@@ -16,4 +28,128 @@ label_dict= {0: 'Agricultural communities',
              14: 'Rural populations',
              15: 'Sexual minorities (LGBTQI+)',
              16: 'Urban populations',
-             17: 'Women and other genders'}
+             17: 'Women and other genders'}
+
+def get_vulnerability_labels(preds):
+    """
+    Takes the numerical predictions as input and returns a list of the label names.
+    """
+
+    # Get label names
+    preds_list = preds.tolist()
+
+    # Get the name of every group where the prediction is equal to 1
+    result = []
+
+    for sublist in preds_list:
+        names = [label_dict[key] for key, value in enumerate(sublist) if value == 1]
+        result.append(names)
+
+    return result
+
+@st.cache_resource
+def load_vulnerabilityClassifier(config_file: str = None, classifier_name: str = None):
+    """
+    Loads the document classifier using Haystack, where the name/path of the model
+    on the HF hub is used to fetch the model object. Either a config file or a
+    model name should be passed.
+    1. https://docs.haystack.deepset.ai/reference/document-classifier-api
+    2. https://docs.haystack.deepset.ai/docs/document_classifier
+
+    Params
+    --------
+    config_file: config file path from which to read the model name
+    classifier_name: if a model name is passed it takes priority; if not
+        found, the config file is used, otherwise an error is raised.
+
+    Return: document classifier model
+    """
+
+    # If no classifier is given
+    if not classifier_name:
+        if not config_file:
+            logging.warning("Pass either model name or config file")
+            return
+        else:
+            config = getconfig(config_file)
+            classifier_name = config.get('vulnerability', 'MODEL')
+
+    logging.info("Loading vulnerability classifier")
+
+    # We would use the pipeline because the model is multilabel and the Haystack
+    # DocumentClassifier does not support multilabel.
+    # In the pipeline we use 'sigmoid' to explicitly make it multilabel;
+    # otherwise softmax is applied automatically, which is not desired here.
+    # doc_classifier = TransformersDocumentClassifier(
+    #                     model_name_or_path=classifier_name,
+    #                     task="text-classification",
+    #                     top_k=None)
+
+    # Download model from HF Hub
+    doc_classifier = SetFitModel.from_pretrained(classifier_name)
+
+    # doc_classifier = pipeline("text-classification",
+    #                           model=classifier_name,
+    #                           return_all_scores=True,
+    #                           function_to_apply="sigmoid")
+
+    return doc_classifier
+
+
+@st.cache_data
+def vulnerability_classification(haystack_doc: pd.DataFrame,
+                                 threshold: float = 0.5,
+                                 classifier_model: pipeline = None
+                                 ) -> Tuple[DataFrame, Series]:
+    """
+    Text classification on the list of texts provided. The classifier assigns the
+    most appropriate labels to each text; these labels indicate whether the text
+    references a group in a vulnerable situation.
+    ---------
+    haystack_doc: list of Haystack Documents. The output of the preprocessing pipeline
+        contains the list of paragraphs in different formats; here the list of
+        Haystack Documents is used.
+    threshold: threshold value for the model to keep the results from the classifier
+    classifier_model: the classifier model can be passed directly, in which case it takes
+        priority; otherwise the model is looked up in the Streamlit session.
+        In a Streamlit app, avoid passing the model directly.
+
+    Returns
+    ----------
+    df: Dataframe with the text and the assigned 'Vulnerability Label'
+    """
+    logging.info("Working on vulnerability identification")
+    haystack_doc['Vulnerability Label'] = 'NA'
+    # haystack_doc['PA_check'] = haystack_doc['Policy-Action Label'].apply(lambda x: True if len(x) != 0 else False)
+
+    # df1 = haystack_doc[haystack_doc['PA_check'] == True]
+    # df = haystack_doc[haystack_doc['PA_check'] == False]
+    if not classifier_model:
+        classifier_model = st.session_state['vulnerability_classifier']
+
+    predictions = classifier_model(list(haystack_doc.text))
+
+    pred_labels = get_vulnerability_labels(predictions)
+
+    haystack_doc['Vulnerability Label'] = pred_labels
+    # placeholder = {}
+    # for j in range(len(temp)):
+    #     placeholder[temp[j]['label']] = temp[j]['score']
+    # list_.append(placeholder)
+    # labels_ = [{**list_[l]} for l in range(len(predictions))]
+    # truth_df = DataFrame.from_dict(labels_)
+    # truth_df = truth_df.round(2)
+    # truth_df = truth_df.astype(float) >= threshold
+    # truth_df = truth_df.astype(str)
+    # categories = list(truth_df.columns)
+    # truth_df['Vulnerability Label'] = truth_df.apply(lambda x: {i if x[i]=='True' else
+    #     None for i in categories}, axis=1)
+    # truth_df['Vulnerability Label'] = truth_df.apply(lambda x: list(x['Vulnerability Label']
+    #     -{None}), axis=1)
+    # haystack_doc['Vulnerability Label'] = list(truth_df['Vulnerability Label'])
+    return haystack_doc
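Similarly, a hedged example of the multi-label decoding in `get_vulnerability_labels`, with an invented prediction tensor of shape (n_paragraphs, n_labels):

```python
import torch

# Fake multi-label predictions for two paragraphs over the 18 groups (values invented)
preds = torch.zeros((2, 18), dtype=torch.int64)
preds[0, 1] = 1     # Children
preds[0, 17] = 1    # Women and other genders
# second row stays all-zero -> no group detected

print(get_vulnerability_labels(preds))
# [['Children', 'Women and other genders'], []]
```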