import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from safetensors.torch import load_file
import torch.nn as nn


# Define the model class (same as in the training script)
class Magical1Sun(nn.Module):
    """SentenceTransformer encoder followed by a small MLP classification head.

    Text is embedded with ``all-MiniLM-L12-v2`` and the embedding is passed
    through ``Dropout -> Linear -> ReLU -> Dropout -> Linear`` to produce one
    logit per class. The layer layout must match the training script exactly
    so the saved state dict loads.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability for the embedding and the hidden
            layer (inactive once the model is in ``eval()`` mode).
    """

    def __init__(self, num_classes, dropout_rate=0.1):
        super().__init__()
        self.sentence_transformer = SentenceTransformer('all-MiniLM-L12-v2')
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Sequential(
            nn.Linear(384, 256),  # 384 = embedding width of all-MiniLM-L12-v2
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes),
        )

    def forward(self, text):
        """Embed ``text`` and return the classifier logits."""
        embeddings = self.sentence_transformer.encode(text, convert_to_tensor=True)
        # encode() returns the tensor on the SentenceTransformer's own device,
        # which it chooses itself (CUDA when available). Move it to the head's
        # device so the encoder and the classifier can never mismatch.
        embeddings = embeddings.to(self.classifier[0].weight.device)
        embeddings = self.dropout(embeddings)
        return self.classifier(embeddings)


# Load the trained model
def load_model(model_path, device=None):
    """Build a ``Magical1Sun``, load trained weights, and return it in eval mode.

    Args:
        model_path: Path to the ``.safetensors`` state dict saved by training.
        device: Device to place the model on. Defaults to ``'cuda'`` when
            available, otherwise ``'cpu'``.

    Returns:
        The loaded ``Magical1Sun`` model, ready for inference.
    """
    model = Magical1Sun(num_classes=2)
    # load_file reads tensors onto the CPU; load_state_dict copies them into
    # the existing parameters wherever those currently live.
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    return model


# Prediction function
def predict(text):
    """Classify a single string as positive or negative.

    Uses the module-level ``model``.

    Returns:
        dict with ``"Prediction"`` (``"Positive"`` or ``"Negative"``) plus
        ``"Confidence"``, ``"Positive Probability"`` and
        ``"Negative Probability"``, each formatted as a percentage string.
    """
    with torch.no_grad():
        output = model(text)
        # Normalize over the last (class) axis.
        probabilities = torch.softmax(output, dim=-1)
    # NOTE(review): assumes label index 1 = positive and 0 = negative, as in
    # the training data — confirm against the training script.
    positive_prob = probabilities[1].item()
    negative_prob = probabilities[0].item()
    prediction = "Positive" if positive_prob > negative_prob else "Negative"
    confidence = max(positive_prob, negative_prob)
    return {
        "Prediction": prediction,
        "Confidence": f"{confidence:.2%}",
        "Positive Probability": f"{positive_prob:.2%}",
        "Negative Probability": f"{negative_prob:.2%}",
    }


# Keys of predict()'s result, in the same order as the interface outputs below.
_OUTPUT_KEYS = (
    "Prediction",
    "Confidence",
    "Positive Probability",
    "Negative Probability",
)


def _predict_for_interface(text):
    """Return predict()'s values as a tuple ordered like the interface outputs.

    With several output components Gradio expects one value per component.
    A dict is only accepted when it is keyed by the component objects, so
    predict()'s string-keyed dict cannot be returned to the UI directly.
    """
    result = predict(text)
    return tuple(result[key] for key in _OUTPUT_KEYS)


# Path to the trained weights produced by the training script; adjust if needed.
MODEL_PATH = 'magical_1_sun.safetensors'
model = load_model(MODEL_PATH)

# Create the Gradio interface
iface = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Textbox(lines=3, placeholder="Enter text to classify..."),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.Label(label="Confidence"),
        gr.Label(label="Positive Probability"),
        gr.Label(label="Negative Probability"),
    ],
    title="Magical-1 Sun Text Classification",
    description="Enter a text to classify it as positive or negative.",
    examples=[
        ["I love this product! It's amazing!"],
        ["This is terrible. Worst purchase ever."],
        ["Great experience overall. Would buy again."],
        ["Never buying again. Complete waste of money."],
        ["Highly recommended! You won't regret it."],
    ],
)

# Launch the app
if __name__ == "__main__":
    iface.launch()