import logging

import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import gradio as gr

logger = logging.getLogger(__name__)

# Load the tokenizers and models.
# Each model is paired with the tokenizer published in its own repository so
# the token ids a model receives always come from its own vocabulary/config.
# `tokenizer` keeps its original name (it is the sentiment model's tokenizer).
tokenizer = RobertaTokenizer.from_pretrained("mental/mental-roberta-base")
# NOTE(review): mental/mental-roberta-base appears to be a domain-pretrained
# (masked-LM) checkpoint rather than a fine-tuned sentiment classifier. If so,
# RobertaForSequenceClassification initialises its classification head
# randomly and the sentiment predictions below are not meaningful — confirm
# and swap in a checkpoint fine-tuned for negative/positive sentiment.
sentiment_model = RobertaForSequenceClassification.from_pretrained("mental/mental-roberta-base")
emotion_tokenizer = RobertaTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
emotion_model = RobertaForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")

# Define the labels, listed in the index order of each model's output logits.
sentiment_labels = ["negative", "positive"]
emotion_labels = ["anger", "disgust", "fear", "joy", "neutral", "sadness", "surprise"]


def _classify(text, model, model_tokenizer, labels, name):
    """Run one sequence classifier on *text* and return its top prediction.

    Args:
        text: Input string; tokenised with truncation to 512 tokens.
        model: A sequence-classification model whose output has ``.logits``.
        model_tokenizer: The tokenizer that belongs to *model*.
        labels: Label names, indexed by the model's output class index.
        name: Human-readable model name used in log and error messages.

    Returns:
        ``(label, probability)`` for the highest-probability class, where
        probability is the softmax score as a Python float.

    Raises:
        ValueError: If the model's number of output classes differs from
            ``len(labels)`` (otherwise labels would be silently misassigned).
    """
    inputs = model_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)

    logger.debug("%s logits shape: %s", name, tuple(logits.shape))
    logger.debug("%s logits: %s", name, logits)
    logger.debug("%s probs: %s", name, probs)

    num_classes = probs.shape[-1]
    if num_classes != len(labels):
        raise ValueError(
            f"{name} model produced {num_classes} classes but {len(labels)} labels are defined"
        )

    max_prob, max_index = torch.max(probs, dim=-1)
    return labels[max_index.item()], max_prob.item()


def analyze_text(text):
    """Classify *text* for overall sentiment and for a specific emotion.

    Args:
        text: The string entered in the Gradio textbox. Inputs longer than
            512 tokens are truncated before classification.

    Returns:
        A 4-tuple of strings ``(sentiment, sentiment_confidence, emotion,
        emotion_confidence)``. Confidences are softmax probabilities formatted
        to four decimal places. On any failure the traceback is logged and
        ``("Error", "N/A", "Error", "N/A")`` is returned so the UI still renders.
    """
    try:
        sentiment, sentiment_prob = _classify(
            text, sentiment_model, tokenizer, sentiment_labels, "Sentiment"
        )
        emotion, emotion_prob = _classify(
            text, emotion_model, emotion_tokenizer, emotion_labels, "Emotion"
        )
    except Exception:
        # Top-level UI boundary: keep the interface responsive with placeholder
        # outputs, but record the full traceback rather than just the message.
        logger.exception("Text analysis failed")
        return "Error", "N/A", "Error", "N/A"
    return sentiment, f"{sentiment_prob:.4f}", emotion, f"{emotion_prob:.4f}"


# Define the Gradio interface
interface = gr.Interface(
    fn=analyze_text,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter text here...",
        value="I don’t know a lot but what I do know is, we don’t start off very big and we all try to make each other smaller.",
    ),
    outputs=[
        gr.Textbox(label="Detected Sentiment"),
        gr.Textbox(label="Sentiment Confidence Score"),
        gr.Textbox(label="Detected Emotion"),
        gr.Textbox(label="Emotion Confidence Score"),
    ],
    title="Sentiment and Emotion Analysis: Detecting Positive/Negative Sentiment and Specific Emotions",
    description="Enter a piece of text to detect overall sentiment (positive or negative) and specific emotions (anger, disgust, fear, joy, neutral, sadness, surprise).",
)

# Launch the interface only when run as a script, so importing this module
# (e.g. for testing or reuse of analyze_text) does not start a server.
if __name__ == "__main__":
    interface.launch()