# Spaces status: Build error
import dataclasses

import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
# Load the pre-trained BERT model and tokenizer.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The head predicts six labels at once (one per label column of the training
# data), each independently. Declaring problem_type explicitly makes the model
# use a per-label sigmoid/BCE loss instead of inferring single-label softmax
# cross-entropy from the label tensor's dtype.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=6,
    problem_type="multi_label_classification",
)
# Read the train and test splits from CSV files in the current working directory.
train_data, test_data = map(pd.read_csv, ("train.csv", "test.csv"))
def preprocess(text, max_length: int = 128):
    """Tokenize *text* with the module-level ``tokenizer``.

    Sequences longer than *max_length* tokens are truncated, and a batch is
    padded to its longest member (``padding=True``).

    Args:
        text: A string, or a list of strings for a batch (anything the
            tokenizer's ``__call__`` accepts).
        max_length: Maximum number of tokens kept per sequence. Defaults to
            128, the same limit used when encoding the train and test data.

    Returns:
        Tuple ``(input_ids, attention_mask)`` of PyTorch tensors.
    """
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return inputs["input_ids"], inputs["attention_mask"]
# Preprocess the train and test data.
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


class _EncodedTextDataset(torch.utils.data.Dataset):
    """Index-aligned view over tokenizer output plus optional label rows.

    ``Trainer``'s default data collator requires each item to be a mapping
    of model-input names to tensors. A plain ``TensorDataset`` yields tuples
    and cannot be collated, so each item is returned as a dict here.
    """

    def __init__(self, encodings, labels=None):
        # encodings: tokenizer output (dict-like of equally long tensors).
        # labels: optional tensor with one row per encoded example.
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        item = {key: values[idx] for key, values in self.encodings.items()}
        if self.labels is not None:
            # "labels" is the keyword the model's forward() uses to compute the loss.
            item["labels"] = self.labels[idx]
        return item


# Empty CSV cells are read by pandas as NaN (a float), which the tokenizer
# rejects, so missing comments are normalised to empty strings first.
X_train = train_data["comment_text"].fillna("").astype(str).tolist()
y_train = train_data[LABEL_COLUMNS].values.tolist()
train_encodings = tokenizer(X_train, padding=True, truncation=True, max_length=128, return_tensors="pt")
# Float targets: multi-label training uses BCEWithLogitsLoss, which needs float
# labels; integer labels would also make transformers infer single-label
# (softmax) classification when problem_type is not set on the model config.
train_dataset = _EncodedTextDataset(train_encodings, torch.tensor(y_train, dtype=torch.float))

X_test = test_data["comment_text"].fillna("").astype(str).tolist()
test_encodings = tokenizer(X_test, padding=True, truncation=True, max_length=128, return_tensors="pt")
# The test split is built without labels (only its text is read above).
test_dataset = _EncodedTextDataset(test_encodings)
# Define the training arguments.
# `evaluation_strategy` was deprecated in favour of `eval_strategy` and dropped
# in later transformers releases, while older releases only know the old name.
# Use whichever field this installed TrainingArguments dataclass declares.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in {field.name for field in dataclasses.fields(TrainingArguments)}
    else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    logging_dir="./logs",
    **{_eval_strategy_key: "epoch"},
)
# Define the trainer.
# test_dataset carries only input ids and attention masks (no labels), so
# evaluating on it yields neither a loss nor metrics. Hold out a fixed,
# reproducible slice of the labelled training data for evaluation instead.
VALIDATION_FRACTION = 0.1
_n_validation = max(1, int(len(train_dataset) * VALIDATION_FRACTION))
train_split, validation_split = torch.utils.data.random_split(
    train_dataset,
    [len(train_dataset) - _n_validation, _n_validation],
    generator=torch.Generator().manual_seed(42),
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=validation_split,
)
# Train the model.
trainer.train()
# Evaluate the model on the trainer's evaluation dataset.
eval_results = trainer.evaluate()
# Print the evaluation results.
print(eval_results)
# Run inference on the test split. The label columns are independent targets,
# so each logit is mapped through its own sigmoid (not a softmax across labels)
# to obtain per-label probabilities.
test_output = trainer.predict(test_dataset)
test_probabilities = torch.sigmoid(torch.as_tensor(test_output.predictions))