# Imports
import torch
from transformers import AutoModel, AutoTokenizer
import gradio as gr
# Load the pretrained backbone used as the encoder inside CamembertClass
path = "./weights"
model = AutoModel.from_pretrained(path, trust_remote_code=True)
class CamembertClass(torch.nn.Module):
    def __init__(self):
        super(CamembertClass, self).__init__()
        self.l1 = model
        self.dropout = torch.nn.Dropout(0.1)
        self.pre_classifier = torch.nn.Linear(1024, 1024)
        self.classifier = torch.nn.Linear(1024, 3)

    def forward(self, input_ids, attention_mask, token_type_ids):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = output_1[0]
        pooler = hidden_state[:, 0]  # representation of the first ([CLS]-style) token
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output
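
# Note (assumption based on how the script uses it): the checkpoint loaded below
# appears to be a full pickled model object rather than a state_dict, so
# torch.load needs the CamembertClass definition above to be available when unpickling.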
# model_gradio = CamembertClass()
path = "./pytorch_model.bin"
model = torch.load(path, map_location="cpu")
path_tokenizer = "./"
tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)
model.eval()  # Put the model in evaluation mode
# Inference function for Gradio
def predict(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Extract the tensors the model expects
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    token_type_ids = inputs.get('token_type_ids', None)  # some models do not use segment IDs

    # Run the forward pass; the model returns logits directly
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    # Convert logits to probabilities
    probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]

    # Replace the following with your actual classes
    # (note: the classifier head defined above outputs 3 logits, but only 2 labels are mapped here)
    classes = ['Negative Sentiment', 'Positive Sentiment']
    return {classes[i]: float(probabilities[i]) for i in range(len(classes))}
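
# predict returns a {label: confidence} dict, which is the format gr.components.Label renders.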
# Build the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.components.Textbox(placeholder="Enter your text here..."),
    outputs=gr.components.Label(num_top_classes=2),
)

iface.launch(share=True)