---
license: cc-by-nc-nd-4.0
language:
- en
model-index:
- name: roberta-large Image Prompt Classifier
  results:
  - task:
      type: text-classification
    dataset:
      name: nsfw-text-detection
      type: custom
    metrics:
    - name: Accuracy
      type: self-reported
      value: 93%
    - name: Precision
      type: self-reported
      value: 88%
    - name: Recall
      type: self-reported
      value: 90%
---

# roberta-large Image Prompt Classifier

## Model Overview

This model is a fine-tuned version of `roberta-large` designed specifically for classifying image generation prompts into three categories: SAFE, QUESTIONABLE, and UNSAFE. Building on the `roberta-large` architecture, it aims to identify the nature of prompts used for generating images with high accuracy and reliability.

## Model Details

- **Model Name:** roberta-large Image Prompt Classifier
- **Base Model:** [roberta-large](https://huggingface.co/roberta-large)
- **Fine-tuned By:** Michał Młodawski
- **Categories:**
  - `0`: SAFE
  - `1`: QUESTIONABLE
  - `2`: UNSAFE

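The category indices above can be mirrored in code with a simple lookup. This is an illustrative sketch; the names `ID2LABEL` and `label_for` are not part of the model's published configuration:

```python
# Hypothetical mapping of class indices to the labels listed above.
ID2LABEL = {0: "SAFE", 1: "QUESTIONABLE", 2: "UNSAFE"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}

def label_for(class_id: int) -> str:
    """Translate a predicted class index into its human-readable label."""
    return ID2LABEL[class_id]
```
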
## Use Cases

This model is particularly useful for platforms and applications involving AI-generated content, where it is crucial to filter and classify prompts to maintain content safety and appropriateness. Some potential applications include:

- **Content Moderation:** Automatically classify and filter prompts to prevent the generation of inappropriate or harmful images.
- **User Safety:** Enhance user experience by ensuring that generated content adheres to safety guidelines.
- **Compliance:** Help platforms comply with regulatory requirements by identifying and flagging potentially unsafe prompts.

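As a sketch of the content-moderation use case, a downstream service might map each predicted class to an action. The policy below is hypothetical and should be adapted to each platform's own guidelines:

```python
def moderation_action(predicted_class: int) -> str:
    """Map a class index (0=SAFE, 1=QUESTIONABLE, 2=UNSAFE) to a hypothetical action."""
    actions = {0: "allow", 1: "flag_for_review", 2: "block"}
    # Unknown indices fall back to human review rather than silently passing.
    return actions.get(predicted_class, "flag_for_review")
```
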
## How It Works

The model takes an input prompt and classifies it into one of three categories:

1. **SAFE:** Prompts that are deemed appropriate and free from harmful content.
2. **QUESTIONABLE:** Prompts that may require further review due to potential ambiguity or slight risk.
3. **UNSAFE:** Prompts that are likely to generate inappropriate or harmful content.

The classification is based on the semantic understanding and contextual analysis provided by the `roberta-large` architecture, fine-tuned on a curated dataset tailored for this specific task.

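Concretely, the classification head emits one logit per category; applying a softmax converts them into probabilities, and the highest-probability index is the predicted class. The logit values below are made up purely for illustration:

```python
import math

def softmax(logits):
    """Convert raw logits into a probability distribution (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for one prompt, ordered [SAFE, QUESTIONABLE, UNSAFE]
probs = softmax([2.0, 0.5, -1.0])
predicted_class = max(range(len(probs)), key=lambda i: probs[i])  # -> 0 (SAFE)
```
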
## Performance

The model was tested and validated to confirm that its performance meets high standards of accuracy and reliability. Key metrics include:

- **Accuracy:** 93%
- **Precision:** 88%
- **Recall:** 90%

These metrics highlight the model's capability to effectively distinguish between different categories of prompts with high confidence.

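For reference, precision and recall derive from confusion-matrix counts as sketched below. The counts used here are invented for illustration and are not the model's actual evaluation data:

```python
def precision_recall(tp: int, fp: int, fn: int):
    """Compute precision and recall from confusion-matrix counts."""
    precision = tp / (tp + fp)  # of everything flagged, how much was correct
    recall = tp / (tp + fn)     # of everything that should be flagged, how much was caught
    return precision, recall

# Invented counts: 88 true positives, 12 false positives, 10 false negatives
precision, recall = precision_recall(88, 12, 10)
```
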
## Getting Started

To start using the roberta-large Image Prompt Classifier, you can easily integrate it into your existing workflows with the following steps:

### Installation

```bash
pip install transformers==4.37.2
pip install torch==2.3.1
```

### Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_repo = "MichalMlodawski/nsfw-text-detection-large"

# Load the tokenizer and model from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSequenceClassification.from_pretrained(model_repo)

def classify_text(text):
    """Classify a prompt: 0 = SAFE, 1 = QUESTIONABLE, 2 = UNSAFE."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()

    return predicted_class

if __name__ == "__main__":
    examples = [
        {"text": "wild west cowboy caravan through Arizona Desert, men, women and children, horse-drawn carriages", "expected_class": 0},
        {"text": "A person with questionable intentions in a dimly lit alley.", "expected_class": 1},
        {"text": "An inappropriate scene involving explicit content.", "expected_class": 2},
    ]

    for example in examples:
        text = example["text"]
        expected_class = example["expected_class"]
        predicted_class = classify_text(text)
        print(f"Prompt: {text}")
        print(f"Predicted class: {predicted_class}")
        print(f"Expected class: {expected_class}")
        print("---")
```

### Disclaimer

The dataset includes real image generation prompts that may be perceived as abusive, offensive, or obscene. Furthermore, the examples and data might contain unfavorable information about certain businesses. This data was merely collected, and no legal responsibility is assumed for its content.

Please note: a portion of the data was created using Large Language Models (LLMs).