|
|
|
|
|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import AutoModel, AutoTokenizer |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
import gradio as gr |
|
from gradio.components import Label |
|
|
|
# Load the base transformer encoder from the local ./weights directory.
# SECURITY: trust_remote_code=True lets transformers execute model code shipped
# alongside the weights -- only point this at a trusted directory.
# NOTE(review): in the visible code this ``model`` is only read by
# CamembertClass.__init__, and the name is rebound further down by ``torch.load``.
# Loading it may still be required so that any remote-code classes are importable
# when that pickled checkpoint is unpickled -- confirm before removing.
path = "./weights"

model = AutoModel.from_pretrained(path, trust_remote_code=True)
|
class CamembertClass(torch.nn.Module):
    """Sequence classifier: a transformer encoder plus a small head on the first token.

    The head takes the hidden state at position 0 of the encoder's first output and
    applies Linear -> ReLU -> Dropout -> Linear, returning raw (unnormalized) logits.

    NOTE: a pickled instance of this class is presumably what ``torch.load`` restores
    later in this script. Unpickling restores the attribute dict without calling
    ``__init__``, so ``forward`` must only read the attributes ``l1``, ``dropout``,
    ``pre_classifier`` and ``classifier``.
    """

    def __init__(self, encoder=None, hidden_size=1024, num_classes=3, dropout_prob=0.1):
        """Build the classifier.

        Args:
            encoder: Transformer module called as
                ``encoder(input_ids=..., attention_mask=..., token_type_ids=...)``,
                whose first output item is the sequence of hidden states. Defaults to
                the module-level ``model`` (the original behavior). ``model`` is
                rebound later in this script, so pass the encoder explicitly when
                constructing after that point.
            hidden_size: Width of the encoder hidden states (1024 in the original).
            num_classes: Number of output logits (3 in the original).
            dropout_prob: Dropout probability applied before the final layer.
        """
        super().__init__()
        self.l1 = model if encoder is None else encoder
        self.dropout = torch.nn.Dropout(dropout_prob)
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.classifier = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return classification logits of shape ``(batch, num_classes)``.

        ``token_type_ids`` may be ``None``; the caller in this script passes ``None``
        when the tokenizer does not produce them.
        """
        output_1 = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_state = output_1[0]
        # Use the first token's hidden state as the pooled sentence representation.
        pooler = hidden_state[:, 0]
        # Functional ReLU: avoids constructing a new nn.ReLU module on every call.
        pooler = torch.relu(self.pre_classifier(pooler))
        pooler = self.dropout(pooler)
        return self.classifier(pooler)
|
|
|
|
|
# Load the fine-tuned classifier checkpoint onto the CPU.
# The code below calls ``.eval()`` on the result and invokes it directly, so this
# file holds a fully pickled module rather than a bare state_dict. Unpickling
# arbitrary classes must therefore be allowed explicitly: since torch 2.6,
# ``torch.load`` defaults to ``weights_only=True`` and would reject this file.
# SECURITY: unpickling can execute arbitrary code -- only load a trusted local file.
path = "./pytorch_model.bin"

model = torch.load(path, map_location="cpu", weights_only=False)

# The tokenizer files are expected in the current directory.
path_tokenizer = "./"

tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)

# Inference only: disable dropout and other training-mode behavior.
model.eval()
|
|
|
|
|
def predict(text):
    """Classify ``text`` and return label -> probability for Gradio's ``Label`` output.

    Args:
        text: Raw input string from the Gradio textbox.

    Returns:
        dict mapping each label name to its softmax probability as a Python float.

    Raises:
        ValueError: if the model produces fewer scores than there are labels.
    """
    # Label order must match the model's output indices.
    # NOTE(review): CamembertClass builds its head with 3 outputs, but only two labels
    # are listed here. The third score is therefore dropped, and the reported
    # probabilities may not sum to 1 -- confirm the intended label set.
    classes = ("Negative Sentiment", "Positive Sentiment")

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # The tokenizer may not return token_type_ids; the model then gets None.
            token_type_ids=inputs.get("token_type_ids", None),
        )

    # Accept any of three output forms:
    #   - a raw logits tensor (what CamembertClass.forward returns);
    #   - a Hugging Face ModelOutput exposing ``.logits``;
    #   - a tuple whose first item is the logits.
    logits = getattr(outputs, "logits", outputs)
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    # A single input string yields a batch of one; keep row 0.
    probabilities = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    if len(probabilities) < len(classes):
        raise ValueError(
            f"model returned {len(probabilities)} scores but {len(classes)} labels are defined"
        )

    return {label: float(prob) for label, prob in zip(classes, probabilities)}
|
|
|
|
|
# Web UI: a free-text box goes in; the top two label probabilities from ``predict``
# come out.
_text_box = gr.components.Textbox(placeholder="Enter your text here...")
_label_view = Label(num_top_classes=2)

iface = gr.Interface(
    fn=predict,
    inputs=_text_box,
    outputs=_label_view,
)

# share=True additionally asks Gradio for a public share link.
iface.launch(share=True)
|
|