# humanizer_model / humanizer.py
# Commit 0f051eb by lucidmorto:
#   feat: Update model and parameters for improved text humanization
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import EarlyStoppingCallback, get_linear_schedule_with_warmup
from transformers.integrations import TensorBoardCallback
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Optional cap on the number of comments used for training/evaluation.
# ``None`` keeps the whole split; set an int to work on a shuffled subset.
MAX_SAMPLES = None

# Load and shuffle the dataset.
logger.info("Loading dataset...")
# With ``split="train"`` this returns a single ``Dataset`` (not a
# ``DatasetDict``), so it is split directly below rather than via ["train"].
dataset = load_dataset("fddemarco/pushshift-reddit-comments", split="train")
dataset = dataset.shuffle(seed=42)
if MAX_SAMPLES is not None:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
logger.info(f"Dataset loaded and shuffled: {len(dataset):,} samples.")

# Split 80/10/10 into train / test / validation.
train_testvalid = dataset.train_test_split(test_size=0.2, seed=42)
test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=42)
dataset = DatasetDict({
    "train": train_testvalid["train"],
    "test": test_valid["test"],
    "validation": test_valid["train"],
})
# Prepare the dataset
def prepare_data(example):
    """Turn one comment record into an (input_text, target_text) pair.

    The comment's ``body`` is used both as the input and as the target.
    NOTE(review): input and target are identical, so the model is trained to
    reproduce its input; confirm this is the intended objective.

    A missing (``None``) body is mapped to the empty string so that the later
    batched tokenizer call does not fail on a non-string entry.

    Args:
        example: A single dataset row; only its ``"body"`` key is read.

    Returns:
        A dict with ``"input_text"`` and ``"target_text"`` keys.
    """
    body = example["body"]
    if body is None:
        body = ""
    return {"input_text": body, "target_text": body}
# Apply prepare_data to every split, producing input/target text columns.
logger.info("Preparing dataset...")
processed_dataset = {}
for split_name, split_data in dataset.items():
    processed_dataset[split_name] = split_data.map(prepare_data)
logger.info("Dataset prepared.")
# Tokenize the dataset
model_name = "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples, max_length=256):
    """Tokenize a batch of input/target texts for seq2seq training.

    Both sides are truncated and padded to ``max_length`` tokens. In the
    labels, padding positions are replaced with -100, the index ignored by
    the model's cross-entropy loss, so the model is not trained to predict
    padding.

    Args:
        examples: A batch (dict of lists) with ``"input_text"`` and
            ``"target_text"`` columns.
        max_length: Maximum sequence length for inputs and labels.

    Returns:
        The tokenizer output for the inputs with an added ``"labels"`` entry.
    """
    model_inputs = tokenizer(
        examples["input_text"], max_length=max_length, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        examples["target_text"], max_length=max_length, truncation=True, padding="max_length"
    )
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs
# Run the batched tokenizer over every prepared split.
logger.info("Tokenizing dataset...")
tokenized_dataset = {}
for split_name, split_data in processed_dataset.items():
    tokenized_dataset[split_name] = split_data.map(tokenize_function, batched=True)
logger.info("Dataset tokenized.")
# Report which splits are present after preprocessing.
available_splits = [*tokenized_dataset]
logger.info(f"Available splits in the dataset: {available_splits}")
# Set up the model and trainer
logger.info("Setting up model and trainer...")
model = T5ForConditionalGeneration.from_pretrained(model_name)
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # The Trainer builds the AdamW optimizer and the linear warmup/decay
    # schedule from these values and computes the total number of optimizer
    # steps itself. This accounts for batch size, gradient accumulation and
    # device count, which a hand-built scheduler sized from the sample count
    # got wrong.
    learning_rate=3e-5,
    warmup_steps=2000,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    evaluation_strategy="steps",
    eval_steps=2000,
    save_steps=2000,
    use_cpu=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # NOTE(review): T5 checkpoints are reported to overflow (NaN loss) under
    # fp16; consider bf16=True on hardware that supports it.
    fp16=True,
    gradient_accumulation_steps=2,
    predict_with_generate=True,
    generation_max_length=256,
    generation_num_beams=4,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # Early stopping and best-checkpoint selection run on the validation
    # split, so the test split stays unseen until the final evaluation.
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3), TensorBoardCallback()],
)
logger.info("Model and trainer set up.")
# Train the model
logger.info("Starting training...")
trainer.train()
logger.info("Training completed.")
# Log final results on the held-out test split.
logger.info("Evaluating model...")
results = trainer.evaluate(eval_dataset=tokenized_dataset["test"], metric_key_prefix="test")
logger.info(f"Final evaluation results: {results}")
# Save the model and tokenizer to the Hugging Face Hub
logger.info("Saving model and tokenizer to Hugging Face Hub...")
hub_repo_id = "umut-bozdag/humanize_model"
# Trainer.push_to_hub() takes a commit message as its first argument, not a
# repo id. Push the trained model and its tokenizer explicitly to the same
# repo instead. With load_best_model_at_end, trainer.model holds the best
# checkpoint.
trainer.model.push_to_hub(hub_repo_id)
tokenizer.push_to_hub(hub_repo_id)
logger.info(f"Model and tokenizer saved successfully as '{hub_repo_id}'")