# Spaces: Running on T4
import logging
from typing import Callable, List, Optional, Tuple

import pandas as pd
import streamlit as st
from pandas import DataFrame, Series
from setfit import SetFitModel
from transformers import pipeline
from typing_extensions import Literal

from utils.config import getconfig
from utils.preprocessing import processingpipeline
## Labels dictionary ###
# Maps a prediction index (see get_target_labels) to its label name.
label_dict = {
    0: 'NO',
    1: 'YES',
}
def get_target_labels(preds, label_map: Optional[dict] = None):
    """
    Convert numerical classifier predictions into a list of label names.

    Two prediction layouts are supported:

    * one-hot rows, e.g. ``[[1, 0], [0, 1]]``: the position of the first ``1``
      in each row is looked up in the label map; a row without any ``1``
      (no label predicted) becomes ``"Other"``;
    * one class id per text, e.g. ``[0, 1]``: the id itself is looked up in
      the label map.

    Params
    --------
    preds: predictions returned by the classifier. Objects exposing
        ``.numpy()`` (e.g. torch tensors) are converted first; numpy arrays
        and plain Python sequences are accepted as well.
    label_map: mapping from class index to label name. Defaults to the
        module-level ``label_dict``.

    Return: list with one label name per prediction.
    """
    labels = label_dict if label_map is None else label_map

    # Normalise tensors / arrays / sequences into plain Python lists.
    if hasattr(preds, "numpy"):
        preds = preds.numpy()
    rows = preds.tolist() if hasattr(preds, "tolist") else list(preds)

    names = []
    for row in rows:
        if isinstance(row, (list, tuple)):
            # One-hot row: the label is the position of the first 1.
            try:
                position = row.index(1)
            except ValueError:
                # No label predicted for this text.
                names.append("Other")
                continue
            names.append(labels[position])
        else:
            # Single class id per text.
            names.append(labels[int(row)])
    return names
def load_targetClassifier(config_file: Optional[str] = None,
                          classifier_name: Optional[str] = None):
    """
    Load the target/action SetFit classifier.

    The model name/path is taken from ``classifier_name`` when given;
    otherwise it is read from the ``MODEL`` key of the ``target`` section of
    ``config_file``. Either ``classifier_name`` or ``config_file`` should be
    passed.

    Params
    --------
    config_file: config file path from which to read the model name.
    classifier_name: model name/path passed to ``SetFitModel.from_pretrained``;
        takes priority over ``config_file``.

    Return: the loaded ``SetFitModel``, or ``None`` (after logging a warning)
        when neither argument is provided.
    """
    if not classifier_name:
        if not config_file:
            logging.warning("Pass either model name or config file")
            return None
        config = getconfig(config_file)
        classifier_name = config.get('target','MODEL')

    logging.info("Loading classifier")
    # Load the model that was actually resolved above, so callers and the
    # config file can choose which classifier is used.
    doc_classifier = SetFitModel.from_pretrained(classifier_name)
    return doc_classifier
def target_classification(haystack_doc: pd.DataFrame,
                          threshold: float = 0.5,
                          classifier_model: Optional[Callable] = None
                          ) -> DataFrame:
    """
    Classify each paragraph as referencing a specific action, target or
    measure.

    The classifier is applied to the ``text`` column of ``haystack_doc``, and
    the label names produced by ``get_target_labels`` are written to a
    ``'Target Label'`` column.

    Params
    ---------
    haystack_doc: DataFrame of paragraphs with a ``text`` column (output of
        the preprocessing pipeline). It is modified in place.
    threshold: currently unused; kept for interface compatibility.
    classifier_model: callable taking a list of texts and returning
        predictions. If passed, it takes priority; otherwise the model stored
        in ``st.session_state['target_classifier']`` is used. In case of
        streamlit avoid passing the model directly.

    Returns
    ----------
    haystack_doc: the same DataFrame, with the ``'Target Label'`` column added.
    """
    logging.info("Working on target/action identification")
    # Placeholder so the column exists even when nothing gets classified.
    haystack_doc['Target Label'] = 'NA'
    if haystack_doc.empty:
        # Nothing to classify; don't invoke (or load) the model.
        return haystack_doc

    if classifier_model is None:
        classifier_model = st.session_state['target_classifier']

    predictions = classifier_model(haystack_doc['text'].tolist())
    haystack_doc['Target Label'] = get_target_labels(predictions)
    return haystack_doc