import os

# Point the Hugging Face cache at the working directory. HF_HOME is read when
# the Hugging Face libraries are imported, so set it before importing them.
dir_path = os.path.abspath('./')
os.environ["HF_HOME"] = dir_path

import streamlit as st
import traceback
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer

start_training = st.button("Train Model")


def tokenize_function(examples):
    # Concatenate Instruction and Response into a single training text
    combined_texts = [instr + " " + resp for instr, resp in zip(examples["Instruction"], examples["Response"])]
    # return tokenizer(combined_texts, padding="max_length", truncation=True)
    # Uses the global `tokenizer` created once training starts
    tokenized_inputs = tokenizer(combined_texts, padding="max_length", truncation=True, max_length=512)
    # Causal LM training: labels are a copy of the input ids
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
    return tokenized_inputs


if start_training:
    st.write("Getting model and dataset ...")

    # Load the dataset
    dataset = load_dataset("viber1/indian-law-dataset", cache_dir=dir_path)

    # Load the GPT-2 tokenizer; GPT-2 has no pad token, so reuse the EOS token for padding
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # Load the model
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    model.gradient_checkpointing_enable()

    st.write("Training setup ...")

    # Apply the tokenizer to the dataset
    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Split the dataset manually into train and validation sets
    split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1)

    # Convert the dataset to PyTorch tensors
    train_dataset = split_dataset["train"].with_format("torch")
    eval_dataset = split_dataset["test"].with_format("torch")

    # Create data loaders (batch size reduced from 8 to 1 to fit in memory).
    # NOTE: Trainer builds its own dataloaders from the datasets and the batch-size
    # arguments below, so these two loaders are not used by trainer.train().
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=1, pin_memory=True)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",  # named `evaluation_strategy` in older transformers releases
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        weight_decay=0.01,
        fp16=True,  # Enable mixed precision (requires a CUDA GPU)
        # save_total_limit=2,
        logging_dir='./logs',  # Set logging directory
        logging_steps=10,  # Log more frequently
        gradient_checkpointing=True,  # Enable gradient checkpointing
        gradient_accumulation_steps=8  # Accumulate gradients over 8 steps (effective batch size 8)
    )

    st.write("Training Started .....")

    # Create the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    try:
        trainer.train()
    except Exception as e:
        st.write(f"Error: {e}")
        traceback.print_exc()
        st.write("Training failed; see the console traceback for details.")
        st.stop()  # don't attempt evaluation after a failed run

    # Evaluate the model
    st.write("Training Done ...")
    results = trainer.evaluate()
    st.write(results)
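
# --- Optional sketch (not invoked above): mask padding in the labels ---------
# tokenize_function copies input_ids straight into labels, so the loss is also
# computed on padding positions. A common alternative is to set padded label
# positions to -100, which the model's cross-entropy loss ignores. This variant
# is only an illustrative sketch under that assumption; the function name is
# hypothetical and nothing in the script calls it.
def tokenize_function_with_masked_padding(examples):
    combined_texts = [instr + " " + resp
                      for instr, resp in zip(examples["Instruction"], examples["Response"])]
    tokenized = tokenizer(combined_texts, padding="max_length", truncation=True, max_length=512)
    # Keep real tokens as labels, replace padded positions with -100
    tokenized["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    return tokenized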