"""Fine-tune T5 to map ``formal_text`` (model input) back to the original
Reddit comment ``body`` (target), then push the result to the Hugging Face Hub.

Pipeline: load + shuffle + subsample comments, split 80/10/10 into
train/validation/test, build (formal_text -> body) pairs, tokenize, train with
Seq2SeqTrainer (early stopping on validation loss), evaluate on the held-out
test split, and upload model + tokenizer.
"""
import logging
import math

import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import EarlyStoppingCallback, get_linear_schedule_with_warmup
from transformers.integrations import TensorBoardCallback

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "t5-base"                     # base checkpoint to fine-tune
HUB_MODEL_ID = "umut-bozdag/humanize_model"  # destination repo on the Hub
NUM_SAMPLES = 10_000                       # rows kept after shuffling
MAX_LENGTH = 256                           # token length for inputs and labels
SEED = 42

# Load the dataset, shuffle it, and keep only NUM_SAMPLES rows.
# load_dataset(..., split="train") returns a single Dataset (not a DatasetDict),
# so it is split manually below.
# NOTE(review): this materializes the full "train" split before subsampling;
# consider streaming if the download is too large.
logger.info("Loading dataset...")
raw_dataset = load_dataset("fddemarco/pushshift-reddit-comments", split="train")
raw_dataset = raw_dataset.shuffle(seed=SEED)
raw_dataset = raw_dataset.select(range(min(NUM_SAMPLES, len(raw_dataset))))
logger.info(f"Dataset loaded, shuffled, and truncated to {len(raw_dataset):,} samples.")

# Split into train (80%), validation (10%) and test (10%).
train_testvalid = raw_dataset.train_test_split(test_size=0.2, seed=SEED)
test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=SEED)
dataset = DatasetDict({
    "train": train_testvalid["train"],
    "test": test_valid["test"],
    "validation": test_valid["train"],
})


def generate_formal_text(text):
    """Return a more formal rewrite of *text*.

    Placeholder: currently returns *text* unchanged, so input and target are
    identical until a real implementation is supplied.
    """
    # TODO: implement formal text generation here.
    return text


def prepare_data(example):
    """Add a ``formal_text`` field derived from the comment's ``body``."""
    example["formal_text"] = generate_formal_text(example["body"])
    return example


logger.info("Preparing dataset...")
processed_dataset = dataset.map(prepare_data)
logger.info("Dataset prepared.")

# Tokenize the dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(examples):
    """Tokenize a batch: ``formal_text`` becomes the input, ``body`` the labels.

    Padding positions in the labels are replaced with -100 so the loss
    ignores them (otherwise the model is trained to predict padding).
    """
    model_inputs = tokenizer(examples["formal_text"], max_length=MAX_LENGTH, truncation=True, padding="max_length")
    labels = tokenizer(examples["body"], max_length=MAX_LENGTH, truncation=True, padding="max_length")
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs


logger.info("Tokenizing dataset...")
tokenized_dataset = processed_dataset.map(tokenize_function, batched=True)
logger.info("Dataset tokenized.")

# Check available splits in the dataset
available_splits = list(tokenized_dataset.keys())
logger.info(f"Available splits in the dataset: {available_splits}")

# Set up the model and trainer
logger.info("Setting up model and trainer...")
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    # NOTE: renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy="steps",
    # With ~8k training rows and an effective batch of 64 there are only a few
    # hundred optimizer steps, so evaluate/save often enough for early
    # stopping to act. save_steps must stay a multiple of eval_steps.
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    use_cpu=False,  # Use GPU if available
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Mixed precision only when CUDA is present; fp16 on CPU raises an error.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=2,  # Accumulate gradients to simulate larger batch sizes
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

# The scheduler steps once per *optimizer update*, not per example: compute
# the number of updates the Trainer will actually perform.
batches_per_epoch = math.ceil(
    len(tokenized_dataset["train"]) / (training_args.train_batch_size * training_args.world_size)
)
updates_per_epoch = max(1, batches_per_epoch // training_args.gradient_accumulation_steps)
num_training_steps = math.ceil(updates_per_epoch * training_args.num_train_epochs)
# Keep warmup from consuming the whole run on small datasets.
num_warmup_steps = min(training_args.warmup_steps, num_training_steps // 10)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # Model selection / early stopping use validation; test stays held out.
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3), TensorBoardCallback()],
)
logger.info("Model and trainer set up.")

# Train the model
logger.info("Starting training...")
trainer.train()
logger.info("Training completed.")

# Final evaluation on the held-out test split (not used for model selection).
logger.info("Evaluating model on the test split...")
results = trainer.evaluate(eval_dataset=tokenized_dataset["test"], metric_key_prefix="test")
logger.info(f"Final evaluation results: {results}")

# Save the model and tokenizer to the Hugging Face Hub.
# Trainer.push_to_hub's first argument is a commit message, not a repo id, so
# push the (best, reloaded via load_best_model_at_end) model and the tokenizer
# to the target repo explicitly.
logger.info("Saving model and tokenizer to Hugging Face Hub...")
trainer.model.push_to_hub(HUB_MODEL_ID)
tokenizer.push_to_hub(HUB_MODEL_ID)
logger.info(f"Model and tokenizer saved successfully as '{HUB_MODEL_ID}'")