import csv
from typing import Optional
from urllib.request import urlopen

import gradio as gr
import numpy as np
class SentimentTransform:
    def __init__(
        self,
        model_name: str = "cardiffnlp/twitter-roberta-base-sentiment",
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
        sensitivity: float = 0,
        **kwargs,
    ):
        """
        Sentiment Ops.

        Parameters
        ----------
        model_name: str
            The name of the Hugging Face model to use.
        highlight: bool
            Whether to compute SHAP-based token highlights.
        positive_sentiment_name: str
            The label treated as positive when signing the overall score.
        max_number_of_shap_documents: Optional[int]
            Maximum number of highlighted tokens to return.
        min_abs_score: float
            Minimum absolute SHAP score for a token to be kept.
        sensitivity: float
            Confidence threshold for accepting a `neutral` prediction.
            If you are dealing with news sources, you probably want
            less sensitivity.
        """
        self.model_name = model_name
        self.highlight = highlight
        self.positive_sentiment_name = positive_sentiment_name
        self.max_number_of_shap_documents = max_number_of_shap_documents
        self.min_abs_score = min_abs_score
        self.sensitivity = sensitivity
        for k, v in kwargs.items():
            setattr(self, k, v)
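    # The cardiffnlp Twitter models are trained on text where user handles
    # and URLs are masked; this mirrors the preprocessing from the model card.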
    def preprocess(self, text: str):
        new_text = []
        for t in text.split(" "):
            t = "@user" if t.startswith("@") and len(t) > 1 else t
            t = "http" if t.startswith("http") else t
            new_text.append(t)
        return " ".join(new_text)
    @property
    def classifier(self):
        # Lazily build the pipeline on first use, so the model is only
        # downloaded when a prediction is actually requested.
        if not hasattr(self, "_classifier"):
            import transformers

            self._classifier = transformers.pipeline(
                return_all_scores=True,
                model=self.model_name,
            )
        return self._classifier
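    # Alternative to the hardcoded mapping below: fetch the label names for a
    # tweeteval task (e.g. "sentiment") from the cardiffnlp/tweeteval repo.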
    def _get_label_mapping(self, task: str):
        # Note: this is specific to the current model
        mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
        with urlopen(mapping_link) as f:
            lines = f.read().decode("utf-8").split("\n")
        csvreader = csv.reader(lines, delimiter="\t")
        labels = [row[1] for row in csvreader if len(row) > 1]
        return labels
    @property
    def label_mapping(self):
        return {"LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"}
    def analyze_sentiment(
        self,
        text,
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        if text is None:
            return None
        labels = self.classifier([str(text)], truncation=True, max_length=512)
        ind_max = np.argmax([l["score"] for l in labels[0]])
        sentiment = labels[0][ind_max]["label"]
        max_score = labels[0][ind_max]["score"]
        sentiment = self.label_mapping.get(sentiment, sentiment)
        if sentiment.lower() == "neutral" and max_score > self.sensitivity:
            overall_sentiment = 1e-5
        elif sentiment.lower() == "neutral":
            # Confidence is below the sensitivity threshold: fall back to
            # the label with the next-highest score.
            new_labels = labels[0][:ind_max] + labels[0][(ind_max + 1):]
            new_ind_max = np.argmax([l["score"] for l in new_labels])
            new_max_score = new_labels[new_ind_max]["score"]
            new_sentiment = new_labels[new_ind_max]["label"]
            new_sentiment = self.label_mapping.get(new_sentiment, new_sentiment)
            overall_sentiment = self._calculate_overall_sentiment(
                new_max_score, new_sentiment
            )
        else:
            overall_sentiment = self._calculate_overall_sentiment(max_score, sentiment)
        # Nudge an exact zero so the score is never falsy downstream.
        if overall_sentiment == 0:
            overall_sentiment = 1e-5
        if not highlight:
            return {
                "sentiment": sentiment,
                "overall_sentiment_score": overall_sentiment,
            }
        shap_documents = self.get_shap_values(
            text,
            sentiment_ind=ind_max,
            max_number_of_shap_documents=max_number_of_shap_documents,
            min_abs_score=min_abs_score,
        )
        return {
            "sentiment": sentiment,
            "score": max_score,
            "overall_sentiment": overall_sentiment,
            "highlight_chunk_": shap_documents,
        }
    def _calculate_overall_sentiment(self, score: float, sentiment: str):
        if sentiment.lower().strip() == self.positive_sentiment_name:
            return score
        else:
            return -score
    @property
    def explainer(self):
        # Lazily build the SHAP explainer; `shap` is only needed when
        # highlighting is requested.
        if hasattr(self, "_explainer"):
            return self._explainer
        try:
            import shap
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The `shap` package is required for highlighting; "
                "install it with `pip install shap`."
            )
        self._explainer = shap.Explainer(self.classifier)
        return self._explainer
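    # SHAP assigns each token a contribution toward the predicted class;
    # tokens whose absolute contribution clears `min_abs_score` are kept
    # as highlight candidates.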
    def get_shap_values(
        self,
        text: str,
        sentiment_ind: int = 2,
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        """Get SHAP values."""
        shap_values = self.explainer([text])
        cohorts = {"": shap_values}
        cohort_exps = list(cohorts.values())
        feature_names = cohort_exps[0].feature_names
        values = np.array([cohort_exps[i].values for i in range(len(cohort_exps))])
        shap_docs = [
            {"text": v, "score": f}
            for f, v in zip(
                [x[sentiment_ind] for x in values[0][0].tolist()], feature_names[0]
            )
        ]
        sorted_scores = sorted(shap_docs, key=lambda x: x["score"], reverse=True)
        if max_number_of_shap_documents is not None:
            sorted_scores = sorted_scores[:max_number_of_shap_documents]
        return [d for d in sorted_scores if abs(d["score"]) > min_abs_score]
    def transform(self, text):
        return self.analyze_sentiment(
            text,
            highlight=self.highlight,
            max_number_of_shap_documents=self.max_number_of_shap_documents,
            min_abs_score=self.min_abs_score,
        )
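# Quick programmatic check, bypassing the UI (output is illustrative; the
# exact score depends on the downloaded model):
#
#   model = SentimentTransform(sensitivity=0.4)
#   model.transform("I love this product!")
#   # -> {"sentiment": "positive", "overall_sentiment_score": 0.99...}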
def sentiment_classifier(text, model_type, sensitivity):
    if model_type == "Social Media Model":
        model_name = "cardiffnlp/twitter-roberta-base-sentiment"
    else:
        # "Survey Model" is also the fallback default.
        model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
    model = SentimentTransform(model_name=model_name, sensitivity=sensitivity)
    res_dict = model.transform(text)
    return res_dict["sentiment"], res_dict["overall_sentiment_score"]
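# Gradio wiring: the dropdown picks the underlying checkpoint, the slider
# feeds `sensitivity`, and the two textboxes show the label and signed score.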
demo = gr.Interface(
    fn=sentiment_classifier,
    inputs=[
        gr.Textbox(
            placeholder="Put the text here and click 'submit' to predict its sentiment",
            label="Input Text",
        ),
        gr.Dropdown(
            ["Social Media Model", "Survey Model"],
            value="Survey Model",
            label="Select the Model that you want to use.",
        ),
        gr.Slider(
            0,
            1,
            step=0.01,
            label="Sensitivity (How confident the model must be to call a text `neutral`. If you are dealing with news sources, you probably want less sensitivity.)",
        ),
    ],
    outputs=[gr.Textbox(label="Sentiment"), gr.Textbox(label="Sentiment Score")],
)
demo.launch(debug=True)