import os
import traceback

import streamlit as st
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
# Cache Hugging Face models and datasets inside the app directory
dir_path = os.path.abspath('./')
os.environ["HF_HOME"] = dir_path

start_training = st.button("Train Model")
def tokenize_function(examples):
    # Concatenate Instruction and Response into a single training text
    combined_texts = [instr + " " + resp for instr, resp in zip(examples["Instruction"], examples["Response"])]
    # return tokenizer(combined_texts, padding="max_length", truncation=True)
    tokenized_inputs = tokenizer(combined_texts, padding="max_length", truncation=True, max_length=512)
    # For causal LM fine-tuning, the labels are the input ids themselves
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
    return tokenized_inputs
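
# Because "labels" is a plain copy of "input_ids", the loss above is also computed on
# padding positions. A common refinement is to mask padded tokens to -100 so the loss
# ignores them. This variant is only a sketch and is not wired into the app;
# tokenize_function_masked is a hypothetical name, not part of the original script.
def tokenize_function_masked(examples):
    combined_texts = [instr + " " + resp for instr, resp in zip(examples["Instruction"], examples["Response"])]
    tokenized = tokenizer(combined_texts, padding="max_length", truncation=True, max_length=512)
    # Copy input_ids into labels, replacing positions the attention mask marks as padding with -100
    tokenized["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    return tokenized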
if start_training:
    st.write("Getting model and dataset ...")

    # Load the dataset
    dataset = load_dataset("viber1/indian-law-dataset", cache_dir=dir_path)

    # 'gpt2' is pulled from the Hub; point this to a local path if the tokenizer files are stored elsewhere
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
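    # GPT-2 ships without a dedicated padding token, so the EOS token is reused for padding above.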
    # Load the model
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    model.gradient_checkpointing_enable()

    st.write("Training setup ...")

    # Apply the tokenizer to the dataset
    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Split the dataset manually into train and validation sets (90/10)
    split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1)

    # Convert the dataset to PyTorch tensors
    train_dataset = split_dataset["train"].with_format("torch")
    eval_dataset = split_dataset["test"].with_format("torch")
    # Create data loaders (batch size reduced from 8 to 1 to fit limited memory)
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=1, pin_memory=True)
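    # Note: Trainer builds its own DataLoaders from train_dataset / eval_dataset,
    # so the two loaders above are not consumed by the training loop below.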
    # Define training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=3,
        weight_decay=0.01,
        fp16=torch.cuda.is_available(),  # Mixed precision only works on GPU; fall back to fp32 on CPU
        # save_total_limit=2,
        logging_dir='./logs',            # Set logging directory
        logging_steps=10,                # Log more frequently
        gradient_checkpointing=True,     # Enable gradient checkpointing
        gradient_accumulation_steps=8    # Accumulate gradients over 8 steps
    )
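    # With per_device_train_batch_size=1 and gradient_accumulation_steps=8, the
    # effective batch size is 1 * 8 = 8 examples per optimizer step.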
st.write("Training Started .....") | |
# Create the Trainer | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=train_dataset, | |
eval_dataset=eval_dataset, | |
) | |
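    # No data_collator is passed here; because tokenize_function already pads to a fixed
    # length and sets "labels", the Trainer's default collation is sufficient. Passing an
    # explicit collator (e.g. transformers.default_data_collator) would be an optional refinement.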
    try:
        trainer.train()
    except Exception as e:
        st.write(f"Error: {e}")
        traceback.print_exc()
        st.write("Training failed; see the traceback in the application logs.")
    else:
        st.write("Training Done ...")
        # Evaluate the model
        results = trainer.evaluate()
        st.write(results)
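
        # A minimal sketch of persisting the fine-tuned weights so they can be reloaded later.
        # The "./results/final" path is an assumption, not something the original app defines.
        trainer.save_model("./results/final")
        tokenizer.save_pretrained("./results/final")
        st.write("Model and tokenizer saved to ./results/final")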