File size: 4,278 Bytes
16a1d49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import transformers
# Load the fine-tuned sequence-classification checkpoint and its tokenizer.
_CHECKPOINT = "nebiyu29/fintunned-v2-roberta_GA"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)
# Earlier alternative checkpoint, kept for reference:
# model = transformers.AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
# tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
# Define a function to split a text into segments of 512 tokens
def split_text(text):
# Tokenize the text
tokens = tokenizer.tokenize(text)
# Initialize an empty list for segments
segments = []
# Initialize an empty list for current segment
current_segment = []
# Initialize a counter for tokens
token_count = 0
# Loop through the tokens
for token in tokens:
# Add the token to the current segment
current_segment.append(token)
# Increment the token count
token_count += 1
# If the token count reaches 512 or the end of the text, add the current segment to the segments list
if token_count == 512 or token == tokens[-1]:
# Convert the current segment to a string and add it to the segments list
segments.append(tokenizer.convert_tokens_to_string(current_segment))
# Reset the current segment and the token count
current_segment = []
token_count = 0
# Return the segments list
return segments
# Classify free text by running the model over each token-limited segment.
def classify_text(text):
    """Classify ``text`` with the fine-tuned sequence-classification model.

    The text is split with ``split_text`` so long inputs fit the model's
    context window. Each segment is then classified independently.

    Args:
        text: Raw input string, e.g. from the Gradio textbox.

    Returns:
        A human-readable string with one line per segment. Each line gives
        the predicted label and its softmax probability. If the input yields
        no tokens, a short notice is returned instead.
    """
    segments = split_text(text)
    if not segments:
        return "No text to classify."

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Bind to a new local name. Assigning to ``model`` inside this function
    # would make it local and raise UnboundLocalError on ``model.to``.
    classifier = model.to(device)
    # Take label names from the checkpoint config so they always match the
    # classifier head. Fall back to the raw index if an entry is missing.
    id2label = classifier.config.id2label

    predictions = []
    for segment in segments:
        # truncation=True guards against a segment that re-tokenizes to more
        # tokens than the model accepts.
        inputs = tokenizer(segment, truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        with torch.no_grad():
            outputs = classifier(input_ids, attention_mask=attention_mask)
        probs, preds = extract_predictions(outputs)
        # Batch size is 1, so read row 0 of both the probabilities and the
        # predictions.
        label_idx = int(preds[0])
        predictions.append({
            "segment_text": segment,
            "label": id2label.get(label_idx, str(label_idx)),
            "probability": float(probs[0, label_idx]),
        })

    return "\n".join(
        f"Segment {i}: {p['label']} ({p['probability']:.2%})"
        for i, p in enumerate(predictions, start=1)
    )
# Turn raw classifier outputs into class probabilities and argmax predictions.
def extract_predictions(outputs):
    """Convert model outputs into probabilities and predicted class indices.

    Args:
        outputs: Model output object exposing a ``logits`` tensor. Assumed
            to be 2-D (batch, num_labels), since softmax and argmax both run
            over dim 1 -- confirm for other heads.

    Returns:
        Tuple ``(probs, preds)``: the softmax of the logits over dim 1, and
        the index of the highest-probability class for each row.
    """
    class_probs = torch.softmax(outputs.logits, dim=1)
    return class_probs, class_probs.argmax(dim=1)
# def classify_text(text):
# """
# This function preprocesses, feeds text to the model, and outputs the predicted class.
# """
# inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits # Access logits instead of pipeline output
# predictions = torch.argmax(logits, dim=-1) # Apply argmax for prediction
# return model.config.id2label[predictions.item()] # Map index to class label
# Gradio UI: one text box in; the string returned by classify_text shown out.
interface = gr.Interface(
    fn=classify_text,
    inputs="text",
    outputs="text",
    title="Text Classification Demo",
    description="Enter some text, and the model will classify it.",
    # Expected classes, for reference only (gr.Interface has no `choices`
    # argument): depression, anxiety, bipolar disorder, schizophrenia, PTSD,
    # OCD, ADHD, autism, eating disorder, personality disorder, phobia.
)

# Launch only when executed as a script, so importing this module (e.g. from
# tests or another app) does not start a web server.
if __name__ == "__main__":
    interface.launch()
|