YAML Metadata Warning: empty or missing YAML metadata in the repo card (see https://huggingface.co/docs/hub/model-cards#model-card-metadata).
import os

import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Merged Llama 3.1 bias classifier on the Hugging Face Hub; both the
# tokenizer and the model weights are loaded from this checkpoint.
model_path = "POLLCHECK/Llama3.1-bias-sequence-classifier"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_path),
    AutoModelForSequenceClassification.from_pretrained(model_path),
)
# Inference only: switch the model to evaluation mode (e.g. disables dropout).
model.eval()
# Maps the classifier's output index to a human-readable label.
CLASS_MAPPING = {0: "Biased", 1: "Unbiased"}


def classify_text(text, *, clf_model=None, clf_tokenizer=None, max_length=512):
    """Classify ``text`` as "Biased" or "Unbiased" and return probabilities.

    The text is lower-cased before tokenization (presumably to match how the
    classifier was fine-tuned -- TODO confirm) and truncated to
    ``max_length`` tokens.

    Args:
        text: The input string to classify.
        clf_model: Sequence-classification model to run; defaults to the
            module-level ``model``.
        clf_tokenizer: Tokenizer matching ``clf_model``; defaults to the
            module-level ``tokenizer``.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        A tuple ``(predicted_label, confidence, probabilities)``:
        ``predicted_label`` is looked up in ``CLASS_MAPPING``, ``confidence``
        is the probability of the predicted class, and ``probabilities`` is
        a NumPy array of per-class softmax probabilities.
    """
    # Resolve defaults at call time so the module-level objects are only
    # needed when no explicit model/tokenizer is supplied.
    clf_model = model if clf_model is None else clf_model
    clf_tokenizer = tokenizer if clf_tokenizer is None else clf_tokenizer

    inputs = clf_tokenizer(
        text.lower(), return_tensors="pt", truncation=True, max_length=max_length
    )
    # Place every input tensor on the same device as the model.
    inputs = {k: v.to(clf_model.device) for k, v in inputs.items()}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = clf_model(**inputs).logits

    # Cast to float32 before softmax: NumPy cannot represent bfloat16, so
    # .numpy() would fail for a half/bf16-loaded model, and fp32 softmax is
    # more precise. squeeze() drops the batch dimension (one text -> batch 1).
    probabilities = (
        torch.nn.functional.softmax(logits.float(), dim=1).squeeze().cpu().numpy()
    )
    predicted_class = torch.argmax(logits, dim=1).item()
    confidence = probabilities[predicted_class]
    predicted_label = CLASS_MAPPING[predicted_class]
    return predicted_label, confidence, probabilities
# Load the labelled data. The path can be overridden with the CLEAN_DATA_CSV
# environment variable; the default is the original author's location.
DATA_PATH = os.environ.get(
    "CLEAN_DATA_CSV",
    "/h/sraza/news-media-bias-plus/classifiers/LLM/data/clean_data.csv",
)
NUM_SAMPLES = 5  # how many rows to classify as a quick sanity check

df = pd.read_csv(DATA_PATH)
# pandas reads empty cells as NaN (a float), on which .lower() raises;
# drop rows missing either the text or its label.
df = df.dropna(subset=["text_content", "text_label"])
texts = df["text_content"].astype(str).tolist()
# Lower-case the ground-truth labels so they can be compared
# case-insensitively with the predicted label below.
labels = [str(label).lower() for label in df["text_label"]]

# Classify a few sample texts and show ground truth alongside the prediction.
for text, ground_truth in zip(texts[:NUM_SAMPLES], labels[:NUM_SAMPLES]):
    predicted_label, confidence, probabilities = classify_text(text)
    print(f"Text: {text[:100]}...")  # first 100 characters only
    print(f"Ground Truth: {ground_truth}")
    print(f"Predicted class: {predicted_label}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Probabilities: {probabilities}")
    print(f"Correct: {predicted_label.lower() == ground_truth}")
    print("---")
Downloads last month: 7