# training-gpt-2 / app.py
# Commit 2032ac8 by vakodiya: "Added models and dataset for training" (3.05 kB)
import os
import traceback

import streamlit as st
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
# Directory the app is launched from; Hugging Face downloads are meant to be
# cached here.
# NOTE(review): huggingface_hub / datasets usually read HF_HOME at import time.
# Those imports run above this line, so this assignment may not move their
# default cache. Confirm, or pass cache_dir explicitly.
dir_path = os.path.abspath(os.curdir)
os.environ.update(HF_HOME=dir_path)

# True only on the script rerun triggered by clicking the button.
start_training = st.button(label="Train Model")
def tokenize_function(examples, tok=None, max_length=512):
    """Tokenize a batch of instruction/response pairs for causal-LM training.

    Intended for ``Dataset.map(..., batched=True)``. Each Instruction is joined
    to its Response with a single space, then tokenized, padded and truncated
    to ``max_length`` tokens.

    Args:
        examples: Batch mapping with list-valued "Instruction" and "Response"
            columns. ``None`` entries are treated as empty strings.
        tok: Tokenizer to call. Defaults to the module-level ``tokenizer``
            that is created when training starts.
        max_length: Length every sequence is padded / truncated to.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) plus a
        ``labels`` entry. ``labels`` is a copy of ``input_ids`` with padded
        positions (attention mask 0) set to -100 so the LM loss ignores them.
    """
    if tok is None:
        tok = tokenizer  # module-level global assigned in the training branch

    # NOTE(review): no EOS is appended after the response, so the model is never
    # shown where an answer ends. Consider appending tok.eos_token.
    combined_texts = [
        f"{instr or ''} {resp or ''}"
        for instr, resp in zip(examples["Instruction"], examples["Response"])
    ]
    tokenized_inputs = tok(
        combined_texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    input_ids = tokenized_inputs["input_ids"]
    attention_mask = tokenized_inputs.get("attention_mask")
    if attention_mask is None:
        # No mask returned: fall back to training on every position.
        labels = [list(ids) for ids in input_ids]
    else:
        # In this app the pad token is EOS. Without masking, the hundreds of
        # padding positions per example would dominate the loss and teach the
        # model to emit runs of EOS.
        labels = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(input_ids, attention_mask)
        ]
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
if start_training:
    # Runs only on the rerun triggered by the "Train Model" click.
    st.write("Getting model and dataset ...")

    # Pass cache_dir explicitly to every download (dataset, tokenizer, model).
    # HF_HOME is set after the Hugging Face imports and may not redirect the
    # model/tokenizer cache on its own.
    dataset = load_dataset("viber1/indian-law-dataset", cache_dir=dir_path)

    # Pretrained GPT-2 tokenizer from the Hub. GPT-2 has no pad token, so reuse
    # EOS for padding. tokenize_function reads this module-level `tokenizer`.
    tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=dir_path)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained("gpt2", cache_dir=dir_path)
    model.gradient_checkpointing_enable()

    st.write("Training setup ...")
    # Only the "train" split is used. Tokenize just that split and drop the raw
    # text columns so only model inputs remain.
    train_split = dataset["train"]
    tokenized_train = train_split.map(
        tokenize_function,
        batched=True,
        remove_columns=train_split.column_names,
    )

    # Hold out 10% for evaluation. The fixed seed keeps the split identical
    # across runs so results are comparable.
    split_dataset = tokenized_train.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"].with_format("torch")
    eval_dataset = split_dataset["test"].with_format("torch")

    # Trainer builds its own DataLoaders from these datasets, so none are
    # created here.
    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",
        learning_rate=2e-5,
        # Batch of 1 per device (reduced from 8) with 8 accumulation steps
        # -> effective batch of 8 per device.
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        weight_decay=0.01,
        # fp16 AMP targets CUDA. Many transformers versions reject it on CPU,
        # so enable mixed precision only when a GPU is present.
        fp16=torch.cuda.is_available(),
        logging_dir="./logs",
        logging_steps=10,
        gradient_checkpointing=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # NOTE(review): there is no explicit trainer.save_model(). The fine-tuned
    # weights persist only through whatever checkpoints the TrainingArguments
    # save settings write to ./results. Confirm this is intended.
    st.write("Training Started .....")
    try:
        trainer.train()
    except Exception as e:
        # App-level boundary: report the failure in the server log and the UI
        # instead of going on to claim training finished.
        traceback.print_exc()
        st.error(f"Error: {e}")
        st.exception(e)
    else:
        st.write("Training Done ...")
        results = trainer.evaluate()
        st.write(results)