adding appstore and utils main scripts
- app.py +30 -1
- appStore/__init__.py +1 -0
- appStore/doc_processing.py +80 -0
- appStore/target.py +111 -0
- appStore/vulnerability_analysis.py +169 -0
- utils/target_classifier.py +125 -0
- utils/vulnerability_classifier.py +137 -1
app.py
CHANGED
@@ -2,6 +2,9 @@ import streamlit as st
 from utils.uploadAndExample import add_upload
 from utils.config import model_dict
 from utils.vulnerability_classifier import label_dict
+import appStore.doc_processing as processing
+import appStore.vulnerability_analysis as vulnerability_analysis
+import appStore.target as target_analysis
 
 with st.sidebar:
     # upload and example doc
@@ -23,4 +26,30 @@ with st.sidebar:
 
 with st.container():
     st.markdown("<h2 style='text-align: center;'> Vulnerability Analysis 3.1 </h2>", unsafe_allow_html=True)
-    st.write(' ')
+    st.write(' ')
+
+    with st.expander("ℹ️ - About this app", expanded=False):
+        st.write(
+            """
+            The Vulnerability Analysis App is an open-source digital tool which
+            aims to assist policy analysts and other users in extracting and
+            filtering references to different groups in vulnerable situations
+            from public documents. We use Natural Language Processing (NLP),
+            specifically deep learning-based text representations, to search
+            context-sensitively for mentions of the special needs of groups in
+            vulnerable situations and to cluster them thematically.
+            For more information on the methodology, [click here](https://vulnerability-analysis.streamlit.app/).
+            """)
+
+        st.write("""
+            What happens in the background?
+
+            - Step 1: Once the document is provided to the app, it undergoes *pre-processing*:
+              the document is broken into smaller paragraphs (based on word/sentence count).
+            - Step 2: The paragraphs are then fed to the **Vulnerability Classifier**, which detects
+              whether the paragraph contains one or multiple references to vulnerable groups.
+            """)
+
+        st.write("")
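The three new imports bring in the pipeline stages. This hunk does not show the call sites, but each module exposes an app() that chains through st.session_state; a minimal sketch of the assumed wiring:

# Hypothetical wiring (the call sites are not part of this hunk):
# processing.app() stores the paragraph dataframe under 'key0',
# vulnerability_analysis.app() turns 'key0' into labelled 'key1',
# target_analysis.app() turns 'key1' into 'key2' for display.
processing.app()
vulnerability_analysis.app()
target_analysis.app()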
appStore/__init__.py
ADDED
@@ -0,0 +1 @@
+# more app related files
appStore/doc_processing.py
ADDED
@@ -0,0 +1,80 @@
+# set path
+import glob, os, sys
+sys.path.append('../utils')
+from typing import List, Tuple
+from typing_extensions import Literal
+from haystack.schema import Document
+from utils.config import get_classifier_params
+from utils.preprocessing import processingpipeline, paraLengthCheck
+import streamlit as st
+import logging
+import pandas as pd
+import nltk
+nltk.download('punkt_tab')
+
+params = get_classifier_params("preprocessing")
+
+@st.cache_data
+def runPreprocessingPipeline(file_name: str, file_path: str,
+                             split_by: Literal["sentence", "word"] = 'sentence',
+                             split_length: int = 2, split_respect_sentence_boundary: bool = False,
+                             split_overlap: int = 0, remove_punc: bool = False) -> List[Document]:
+    """
+    Creates and runs the preprocessing pipeline; the pipeline params are
+    fetched from paramconfig.
+
+    Params
+    ------------
+    file_name: filename; in a Streamlit application use st.session_state['filename']
+    file_path: filepath; in a Streamlit application use st.session_state['filepath']
+    split_by: document splitting strategy, either 'word' or 'sentence'
+    split_length: length of each paragraph when synthetically creating
+        paragraphs from the document
+    split_respect_sentence_boundary: used with the 'word' splitting strategy
+    split_overlap: number of words or sentences that overlap between
+        consecutive paragraphs. A sentence or a few words often only make
+        sense when read together with their neighbours, hence the overlap.
+    remove_punc: whether to remove all punctuation, including ',' and '.'
+
+    Return
+    --------------
+    List[Document]: the preprocessing pipeline output dictionary has four
+    objects. For the Haystack implementation of SDG classification we need
+    the list of Haystack Documents, which can be fetched by
+    key = 'documents' on the output.
+    """
+
+    processing_pipeline = processingpipeline()
+
+    output_pre = processing_pipeline.run(file_paths=file_path,
+                                         params={"FileConverter": {"file_path": file_path,
+                                                                   "file_name": file_name},
+                                                 "UdfPreProcessor": {"remove_punc": remove_punc,
+                                                                     "split_by": split_by,
+                                                                     "split_length": split_length,
+                                                                     "split_overlap": split_overlap,
+                                                                     "split_respect_sentence_boundary": split_respect_sentence_boundary}})
+
+    return output_pre
+
+
+def app():
+    with st.container():
+        if 'filepath' in st.session_state:
+            file_name = st.session_state['filename']
+            file_path = st.session_state['filepath']
+
+            all_documents = runPreprocessingPipeline(file_name=file_name,
+                                                     file_path=file_path, split_by=params['split_by'],
+                                                     split_length=params['split_length'],
+                                                     split_respect_sentence_boundary=params['split_respect_sentence_boundary'],
+                                                     split_overlap=params['split_overlap'], remove_punc=params['remove_punc'])
+            # cap paragraph length at 100 words, keeping page numbers
+            paralist = paraLengthCheck(all_documents['documents'], 100)
+            df = pd.DataFrame(paralist, columns=['text', 'page'])
+            # saving the dataframe to session state
+            st.session_state['key0'] = df
+
+        else:
+            st.info("🤔 No document found, please try to upload it at the sidebar!")
+            logging.warning("Terminated as no document provided")
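Outside Streamlit, the pipeline above can be driven directly. A minimal usage sketch, assuming the repo's utils package is importable and 'example.pdf' is a hypothetical local file (not part of the repo):

from appStore.doc_processing import runPreprocessingPipeline
from utils.preprocessing import paraLengthCheck
import pandas as pd

# 'example.pdf' is a placeholder path for illustration only
output = runPreprocessingPipeline(file_name='example.pdf', file_path='example.pdf',
                                  split_by='sentence', split_length=2)
paralist = paraLengthCheck(output['documents'], 100)  # cap paragraphs at ~100 words
df = pd.DataFrame(paralist, columns=['text', 'page'])
print(df.head())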
appStore/target.py
ADDED
@@ -0,0 +1,111 @@
+# set path
+import glob, os, sys
+sys.path.append('../utils')
+
+# import needed libraries
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import streamlit as st
+from utils.target_classifier import load_targetClassifier, target_classification
+import logging
+logger = logging.getLogger(__name__)
+from utils.config import get_classifier_params
+from utils.preprocessing import paraLengthCheck
+from io import BytesIO
+import xlsxwriter
+import plotly.express as px
+from utils.target_classifier import label_dict
+from appStore.rag import run_query
+
+# Declare all the necessary variables
+classifier_identifier = 'target'
+params = get_classifier_params(classifier_identifier)
+
+@st.cache_data
+def to_excel(df, sectorlist):
+    # write the dataframe to an in-memory Excel file and attach dropdown
+    # validations to the review columns
+    len_df = len(df)
+    output = BytesIO()
+    writer = pd.ExcelWriter(output, engine='xlsxwriter')
+    df.to_excel(writer, index=False, sheet_name='Sheet1')
+    workbook = writer.book
+    worksheet = writer.sheets['Sheet1']
+    worksheet.data_validation('S2:S{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': ['No', 'Yes', 'Discard']})
+    worksheet.data_validation('X2:X{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('T2:T{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('U2:U{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('V2:V{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('W2:W{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    writer.save()
+    processed_data = output.getvalue()
+    return processed_data
+
+def app():
+
+    ### Main app code ###
+    with st.container():
+
+        if 'key1' in st.session_state:
+
+            # Load the existing dataset
+            df = st.session_state.key1
+
+            # Filter out all paragraphs that do not have a reference to groups
+            df = df[df['Vulnerability Label'].apply(lambda x: len(x) > 0 and 'Other' not in x)]
+
+            # Load the classifier model
+            classifier = load_targetClassifier(classifier_name=params['model_name'])
+
+            st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
+
+            df = target_classification(haystack_doc=df,
+                                       threshold=params['threshold'])
+
+            # Rename column
+            df.rename(columns={'Target Label': 'Specific action/target/measure mentioned'}, inplace=True)
+
+            st.session_state.key2 = df
+
+
+def target_display(model_sel_name):
+
+    ### TABLE Output ###
+
+    # Assign the dataframe a name
+    df = st.session_state['key2']
+    st.write(df)
+
+    ### RAG Output by group ###
+
+    # Expand the DataFrame: one row per (paragraph, group) pair
+    df_expand = (
+        df.query("`Specific action/target/measure mentioned` == 'YES'")
+        .explode('Vulnerability Label')
+    )
+    # Group by 'Vulnerability Label' and concatenate 'text'
+    df_agg = df_expand.groupby('Vulnerability Label')['text'].agg('; '.join).reset_index()
+
+    st.markdown("----")
+    st.markdown('**DOCUMENT FINDINGS SUMMARY BY VULNERABILITY LABEL:**')
+
+    # construct a RAG query for each label, send it to OpenAI and process the response
+    for i in range(0, len(df_agg)):
+        st.write(df_agg['Vulnerability Label'].iloc[i])
+        run_query(context=df_agg['text'].iloc[i], label=df_agg['Vulnerability Label'].iloc[i], model_sel_name=model_sel_name)
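The explode/groupby step is the core of the per-group summary. A toy, Streamlit-free illustration of what it produces (data is made up for the example):

import pandas as pd

df = pd.DataFrame({
    'text': ['Paragraph about children.', 'Paragraph about women and children.'],
    'Specific action/target/measure mentioned': ['YES', 'YES'],
    'Vulnerability Label': [['Children'], ['Women and other genders', 'Children']],
})

# one row per (paragraph, group) pair, then join all paragraphs per group
df_expand = (df.query("`Specific action/target/measure mentioned` == 'YES'")
               .explode('Vulnerability Label'))
df_agg = df_expand.groupby('Vulnerability Label')['text'].agg('; '.join).reset_index()
print(df_agg)
# 'Children' gets both paragraphs joined with '; '; that joined text is the
# context handed to run_query for the group.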
appStore/vulnerability_analysis.py
ADDED
@@ -0,0 +1,169 @@
+# set path
+import glob, os, sys
+sys.path.append('../utils')
+
+# import needed libraries
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import streamlit as st
+from utils.vulnerability_classifier import load_vulnerabilityClassifier, vulnerability_classification
+import logging
+logger = logging.getLogger(__name__)
+from utils.config import get_classifier_params
+from utils.preprocessing import paraLengthCheck
+from io import BytesIO
+import xlsxwriter
+import plotly.express as px
+import plotly.graph_objects as go
+from utils.vulnerability_classifier import label_dict
+
+
+# Declare all the necessary variables
+classifier_identifier = 'vulnerability'
+params = get_classifier_params(classifier_identifier)
+
+@st.cache_data
+def to_excel(df, sectorlist):
+    # same export helper as in appStore/target.py: write the dataframe to an
+    # in-memory Excel file and attach dropdown validations to the review columns
+    len_df = len(df)
+    output = BytesIO()
+    writer = pd.ExcelWriter(output, engine='xlsxwriter')
+    df.to_excel(writer, index=False, sheet_name='Sheet1')
+    workbook = writer.book
+    worksheet = writer.sheets['Sheet1']
+    worksheet.data_validation('S2:S{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': ['No', 'Yes', 'Discard']})
+    worksheet.data_validation('X2:X{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('T2:T{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('U2:U{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('V2:V{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    worksheet.data_validation('W2:W{}'.format(len_df),
+                              {'validate': 'list',
+                               'source': sectorlist + ['Blank']})
+    writer.save()
+    processed_data = output.getvalue()
+    return processed_data
+
+def app():
+
+    ### Main app code ###
+    with st.container():
+
+        # If a document has been processed
+        if 'key0' in st.session_state:
+
+            # Run the vulnerability classifier
+            df = st.session_state.key0
+            classifier = load_vulnerabilityClassifier(classifier_name=params['model_name'])
+            st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
+
+            # Get the predictions
+            df = vulnerability_classification(haystack_doc=df,
+                                              threshold=params['threshold'])
+
+            # Store df in session state with key1
+            st.session_state.key1 = df
+
+
+def vulnerability_display():
+
+    # Get the vulnerability df
+    df = st.session_state['key1']
+
+    # Filter the dataframe to only show the paragraphs with references
+    df_filtered = df[df['Vulnerability Label'].apply(lambda x: len(x) > 0 and 'Other' not in x)]
+
+    # Rename column
+    df_filtered.rename(columns={'Vulnerability Label': 'Group(s)'}, inplace=True)
+
+    # Header
+    st.subheader("Explore references to vulnerable groups:")
+
+    # Text
+    num_paragraphs = len(df['Vulnerability Label'])
+    num_references = len(df_filtered['Group(s)'])
+
+    st.markdown(f"""<div style="text-align: justify;">The document contains a
+                total of <span style="color: red;">{num_paragraphs}</span> paragraphs.
+                We identified <span style="color: red;">{num_references}</span>
+                references to groups in vulnerable situations.</div>
+                <br>
+                <div style="text-align: justify;">We are searching for references related
+                to the following groups: (1) Agricultural communities, (2) Children, (3)
+                Ethnic, racial and other minorities, (4) Fishery communities, (5) Informal sector
+                workers, (6) Members of indigenous and local communities, (7) Migrants and
+                displaced persons, (8) Older persons, (9) Persons living in poverty, (10)
+                Persons living with disabilities, (11) Persons with pre-existing health conditions,
+                (12) Residents of drought-prone regions, (13) Rural populations, (14) Sexual
+                minorities (LGBTQI+), (15) Urban populations, (16) Women and other genders.</div>
+                <br>
+                <div style="text-align: justify;">The chart below shows the groups for which
+                references were found and the number of references identified.
+                For a more detailed view, see the paragraphs and
+                their respective labels in the table underneath.</div>""", unsafe_allow_html=True)
+
+
+    ### Bar chart
+
+    # Create a df that stores all the labels
+    df_labels = pd.DataFrame(list(label_dict.items()), columns=['Label ID', 'Label'])
+
+    # Count how often each label appears in the "Group(s)" column
+    group_counts = {}
+
+    # Iterate through each row
+    for index, row in df_filtered.iterrows():
+
+        # Iterate through each group in the row's list of labels
+        for group in row['Group(s)']:
+
+            # Update the count in the dictionary
+            group_counts[group] = group_counts.get(group, 0) + 1
+
+    # Create a new dataframe from group_counts
+    df_label_count = pd.DataFrame(list(group_counts.items()), columns=['Label', 'Count'])
+
+    # Merge the label counts with the df_labels DataFrame
+    df_label_count = df_labels.merge(df_label_count, on='Label', how='left')
+
+    # Exclude the "Other" group and all groups that do not have a label
+    df_bar_chart = df_label_count[df_label_count['Label'] != 'Other']
+    df_bar_chart = df_bar_chart.dropna(subset=['Count'])
+
+    # Bar chart
+    fig = go.Figure()
+
+    fig.add_trace(go.Bar(
+        y=df_bar_chart.Label,
+        x=df_bar_chart.Count,
+        orientation='h',
+        marker=dict(color='purple'),
+    ))
+
+    # Customize layout
+    fig.update_layout(
+        title='Number of references identified',
+        xaxis_title='Number of references',
+        yaxis_title='Group',
+    )
+
+    # Show the plot
+    st.plotly_chart(fig, use_container_width=True)
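The explicit counting loop above can be sanity-checked with pandas' explode/value_counts, which computes the same per-group counts; a toy sketch with made-up data:

import pandas as pd

df_filtered = pd.DataFrame({'Group(s)': [['Children'], ['Children', 'Older persons']]})

# one row per group mention, then count occurrences per group
counts = df_filtered['Group(s)'].explode().value_counts()
print(counts.to_dict())  # {'Children': 2, 'Older persons': 1}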
utils/target_classifier.py
ADDED
@@ -0,0 +1,125 @@
+from typing import List, Tuple
+from typing_extensions import Literal
+import logging
+import pandas as pd
+from pandas import DataFrame, Series
+from utils.config import getconfig
+from utils.preprocessing import processingpipeline
+import streamlit as st
+from setfit import SetFitModel
+from transformers import pipeline
+
+### Labels dictionary ###
+label_dict = {
+    0: 'NO',
+    1: 'YES',
+}
+
+def get_target_labels(preds):
+    """
+    Takes the numerical predictions as input and returns a list of the label names.
+    """
+
+    # Turn into list
+    preds_list = preds.numpy().tolist()
+
+    # Get label names
+    predictions_names = []
+
+    # loop through each prediction
+    for ele in preds_list:
+
+        # see if there is a value 1 and retrieve its index
+        try:
+            index_of_one = ele.index(1)
+        except ValueError:
+            index_of_one = "NA"
+
+        # Retrieve the name of the label (if no prediction was made: "Other")
+        if index_of_one != "NA":
+            name = label_dict[index_of_one]
+        else:
+            name = "Other"
+
+        # Append name to list
+        predictions_names.append(name)
+
+    return predictions_names
+
+@st.cache_resource
+def load_targetClassifier(config_file: str = None, classifier_name: str = None):
+    """
+    Loads the document classifier, where the name/path of the model on the
+    HF hub is used to fetch the model object. Either a config file or a
+    model name should be passed.
+    1. https://docs.haystack.deepset.ai/reference/document-classifier-api
+    2. https://docs.haystack.deepset.ai/docs/document_classifier
+    Params
+    --------
+    config_file: config file path from which to read the model name
+    classifier_name: if a model name is passed it takes priority; if not
+        found, the config file is used, else a warning is logged.
+    Return: document classifier model
+    """
+    if not classifier_name:
+        if not config_file:
+            logging.warning("Pass either model name or config file")
+            return
+        else:
+            config = getconfig(config_file)
+            classifier_name = config.get('target', 'MODEL')
+
+    logging.info("Loading classifier")
+
+    # Loading classifier (note: the model is hard-coded here, so classifier_name
+    # is currently not used for the download)
+    doc_classifier = SetFitModel.from_pretrained("leavoigt/vulnerability_target")
+
+    return doc_classifier
+
+
+@st.cache_data
+def target_classification(haystack_doc: pd.DataFrame,
+                          threshold: float = 0.5,
+                          classifier_model: pipeline = None
+                          ) -> Tuple[DataFrame, Series]:
+    """
+    Text classification on the list of texts provided. The classifier assigns
+    the most appropriate label to each text; these labels indicate whether the
+    paragraph references a specific action, target or measure.
+    ---------
+    haystack_doc: DataFrame of paragraphs produced by the preprocessing pipeline.
+    threshold: threshold value for the model to keep the results from the classifier
+    classifier_model: you can pass the classifier model directly, which takes
+        priority; otherwise the model is looked up in the Streamlit session.
+        In a Streamlit app, avoid passing the model directly.
+    Returns
+    ----------
+    haystack_doc: the input DataFrame with an added 'Target Label' column
+        ('YES'/'NO', or 'Other' if no prediction was made).
+    """
+
+    logging.info("Working on target/action identification")
+
+    haystack_doc['Target Label'] = 'NA'
+
+    if not classifier_model:
+        classifier_model = st.session_state['target_classifier']
+
+    # Get predictions
+    predictions = classifier_model(list(haystack_doc.text))
+
+    # Get labels for predictions
+    pred_labels = get_target_labels(predictions)
+
+    # Save labels
+    haystack_doc['Target Label'] = pred_labels
+
+    return haystack_doc
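A quick check of the label mapping above (assumes the repo environment, since the module imports streamlit and setfit; the tensor mimics the 0/1 prediction rows SetFit returns via torch):

import torch
from utils.target_classifier import get_target_labels

preds = torch.tensor([[0, 1],   # 1 at index 1  -> 'YES'
                      [1, 0],   # 1 at index 0  -> 'NO'
                      [0, 0]])  # no 1 anywhere -> 'Other'
print(get_target_labels(preds))  # ['YES', 'NO', 'Other']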
utils/vulnerability_classifier.py
CHANGED
@@ -1,3 +1,15 @@
+from typing import List, Tuple
+from typing_extensions import Literal
+import logging
+import pandas as pd
+from pandas import DataFrame, Series
+from utils.config import getconfig
+from utils.preprocessing import processingpipeline
+import streamlit as st
+from transformers import pipeline
+from setfit import SetFitModel
+
+
 # labels dictionary
 label_dict= {0: 'Agricultural communities',
              1: 'Children',
@@ -16,4 +28,128 @@ label_dict= {0: 'Agricultural communities',
              14: 'Rural populations',
              15: 'Sexual minorities (LGBTQI+)',
              16: 'Urban populations',
-             17: 'Women and other genders'}
+             17: 'Women and other genders'}
+
+def get_vulnerability_labels(preds):
+    """
+    Takes the numerical predictions as input and returns a list of label-name lists.
+    """
+
+    # Turn predictions into a list of lists
+    preds_list = preds.tolist()
+
+    # Get the names of the groups where the prediction is equal to 1
+    result = []
+
+    for sublist in preds_list:
+        names = [label_dict[key] for key, value in enumerate(sublist) if value == 1]
+        result.append(names)
+
+    return result
+
+@st.cache_resource
+def load_vulnerabilityClassifier(config_file: str = None, classifier_name: str = None):
+    """
+    Loads the document classifier, where the name/path of the model on the
+    HF hub is used to fetch the model object. Either a config file or a
+    model name should be passed.
+    1. https://docs.haystack.deepset.ai/reference/document-classifier-api
+    2. https://docs.haystack.deepset.ai/docs/document_classifier
+    Params
+    --------
+    config_file: config file path from which to read the model name
+    classifier_name: if a model name is passed it takes priority; if not
+        found, the config file is used, else a warning is logged.
+    Return: document classifier model
+    """
+
+    # If no classifier name is given, read it from the config file
+    if not classifier_name:
+        if not config_file:
+            logging.warning("Pass either model name or config file")
+            return
+        else:
+            config = getconfig(config_file)
+            classifier_name = config.get('vulnerability', 'MODEL')
+
+    logging.info("Loading vulnerability classifier")
+
+    # We load a SetFit model directly because the model is multilabel and
+    # Haystack's DocumentClassifier does not support multilabel. (With a
+    # transformers pipeline we would pass 'sigmoid' explicitly to make it
+    # multilabel; otherwise softmax is applied, which is not desired here.)
+    # doc_classifier = TransformersDocumentClassifier(
+    #                     model_name_or_path=classifier_name,
+    #                     task="text-classification",
+    #                     top_k=None)
+
+    # Download model from HF Hub
+    doc_classifier = SetFitModel.from_pretrained(classifier_name)
+
+    # doc_classifier = pipeline("text-classification",
+    #                           model=classifier_name,
+    #                           return_all_scores=True,
+    #                           function_to_apply="sigmoid")
+
+    return doc_classifier
+
+
+@st.cache_data
+def vulnerability_classification(haystack_doc: pd.DataFrame,
+                                 threshold: float = 0.5,
+                                 classifier_model: pipeline = None
+                                 ) -> Tuple[DataFrame, Series]:
+    """
+    Text classification on the list of texts provided. The classifier assigns
+    the most appropriate labels to each text; these labels indicate whether the
+    text references a group in a vulnerable situation.
+    ---------
+    haystack_doc: DataFrame of paragraphs produced by the preprocessing pipeline.
+    threshold: threshold value for the model to keep the results from the classifier
+    classifier_model: you can pass the classifier model directly, which takes
+        priority; otherwise the model is looked up in the Streamlit session.
+        In a Streamlit app, avoid passing the model directly.
+    Returns
+    ----------
+    haystack_doc: the input DataFrame with an added 'Vulnerability Label'
+        column holding the list of groups detected in each paragraph.
+    """
+    logging.info("Working on vulnerability identification")
+    haystack_doc['Vulnerability Label'] = 'NA'
+
+    if not classifier_model:
+        classifier_model = st.session_state['vulnerability_classifier']
+
+    predictions = classifier_model(list(haystack_doc.text))
+
+    pred_labels = get_vulnerability_labels(predictions)
+
+    haystack_doc['Vulnerability Label'] = pred_labels
+
+    return haystack_doc
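A quick check of the multilabel mapping (toy 0/1 rows over the 18 entries of label_dict; assumes numpy, matching the array-like predictions SetFit returns):

import numpy as np
from utils.vulnerability_classifier import get_vulnerability_labels

# two paragraphs: the first mentions groups at indices 0 and 2, the second mentions none
preds = np.array([[1, 0, 1] + [0] * 15,
                  [0] * 18])
print(get_vulnerability_labels(preds))
# [['Agricultural communities', 'Ethnic, racial and other minorities'], []]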