Added models and dataset for training
- app.py +87 -0
- models--gpt2/.no_exist/607a30d783dfa663caf39e06633721c8d4cfcd7e/added_tokens.json +0 -0
- models--gpt2/.no_exist/607a30d783dfa663caf39e06633721c8d4cfcd7e/special_tokens_map.json +0 -0
- models--gpt2/refs/main +1 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/config.json +31 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/generation_config.json +6 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/merges.txt +0 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/model.safetensors +3 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/tokenizer.json +0 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/tokenizer_config.json +1 -0
- models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/vocab.json +0 -0
- requirements.txt +6 -0
- viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/cache-22a773a22cb9ef7a.arrow +3 -0
- viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/cache-bef1f90cc85606a0.arrow +3 -0
- viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/cache-ce2a18a54b30a39e.arrow +3 -0
- viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/dataset_info.json +1 -0
- viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/indian-law-dataset-train.arrow +3 -0
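The model and dataset entries above are Hugging Face cache directories committed alongside the app: app.py points HF_HOME (and the datasets cache_dir) at the Space root, so the gpt2 snapshot and the viber1/indian-law-dataset Arrow files land inside the repository itself. A minimal sketch of reusing that bundled snapshot offline — assuming the working directory is the Space root, and that passing cache_dir plus local_files_only is an acceptable way to resolve it — could look like this:

# Sketch only (not part of this commit): resolve gpt2 from the committed
# models--gpt2/ cache instead of downloading it again. cache_dir and
# local_files_only here are assumptions about how the bundled cache is reused.
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

cache_root = os.path.abspath("./")  # Space root, same as dir_path in app.py
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_root, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained("gpt2", cache_dir=cache_root, local_files_only=True)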
app.py
ADDED
@@ -0,0 +1,87 @@
+import os
+import streamlit as st
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
+from torch.utils.data import DataLoader
+import traceback
+
+dir_path = os.path.abspath('./')
+os.environ["HF_HOME"] = dir_path
+start_training = st.button("Train Model")
+
+
+def tokenize_function(examples):
+    # Concatenate Instruction and Response
+    combined_texts = [instr + " " + resp for instr, resp in zip(examples["Instruction"], examples["Response"])]
+    # return tokenizer(combined_texts, padding="max_length", truncation=True)
+    tokenized_inputs = tokenizer(combined_texts, padding="max_length", truncation=True, max_length=512)
+    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
+    return tokenized_inputs
+
+
+if start_training:
+    st.write("Getting model and dataset ...")
+    # Load the dataset
+    dataset = load_dataset("viber1/indian-law-dataset", cache_dir=dir_path)
+
+    # Update this path based on where the tokenizer files are actually stored
+    tokenizer = AutoTokenizer.from_pretrained('gpt2')
+    tokenizer.pad_token = tokenizer.eos_token
+    # Load the model
+    model = AutoModelForCausalLM.from_pretrained('gpt2')
+    model.gradient_checkpointing_enable()
+
+    st.write("Training setup ...")
+    # Apply the tokenizer to the dataset
+    tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+    # Split the dataset manually into train and validation sets
+    split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1)
+
+    # Convert the dataset to PyTorch tensors
+    train_dataset = split_dataset["train"].with_format("torch")
+    eval_dataset = split_dataset["test"].with_format("torch")
+
+    # Create data loaders
+    # reduce batch size 8 to 1
+    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)
+    eval_dataloader = DataLoader(eval_dataset, batch_size=1, pin_memory=True)
+
+    # Define training arguments
+    training_args = TrainingArguments(
+        output_dir="./results",
+        eval_strategy="epoch",
+        learning_rate=2e-5,
+        per_device_train_batch_size=1,
+        per_device_eval_batch_size=1,
+        num_train_epochs=3,
+        weight_decay=0.01,
+        fp16=True,  # Enable mixed precision
+        # save_total_limit=2,
+        logging_dir='./logs',  # Set logging directory
+        logging_steps=10,  # Log more frequently
+        gradient_checkpointing=True,  # Enable gradient checkpointing
+        gradient_accumulation_steps=8  # Accumulate gradients over 8 steps
+    )
+
+    st.write("Training Started .....")
+
+    # Create the Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+    )
+
+    try:
+        trainer.train()
+    except Exception as e:
+        st.write(f"Error: {e}")
+        traceback.print_exc()
+        st.write("some error")
+
+    # Evaluate the model
+    st.write("Training Done ...")
+    results = trainer.evaluate()
+    st.write(results)
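For reference, a minimal generation sketch (not part of this commit): the Trainer above writes checkpoints under ./results, so after a run one could load a checkpoint and query the fine-tuned model. The checkpoint path below is a placeholder — the actual directory name depends on the global step at save time.

# Hypothetical follow-up, not included in app.py: generate from a saved checkpoint.
# "./results/checkpoint-1000" is illustrative; pick an existing checkpoint directory.
from transformers import AutoTokenizer, AutoModelForCausalLM

checkpoint_dir = "./results/checkpoint-1000"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)

prompt = "What is the punishment for theft under the Indian Penal Code?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.9,
                         pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))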
models--gpt2/.no_exist/607a30d783dfa663caf39e06633721c8d4cfcd7e/added_tokens.json
ADDED
File without changes
models--gpt2/.no_exist/607a30d783dfa663caf39e06633721c8d4cfcd7e/special_tokens_map.json
ADDED
File without changes
models--gpt2/refs/main
ADDED
@@ -0,0 +1 @@
+607a30d783dfa663caf39e06633721c8d4cfcd7e
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/config.json
ADDED
@@ -0,0 +1,31 @@
+{
+  "activation_function": "gelu_new",
+  "architectures": [
+    "GPT2LMHeadModel"
+  ],
+  "attn_pdrop": 0.1,
+  "bos_token_id": 50256,
+  "embd_pdrop": 0.1,
+  "eos_token_id": 50256,
+  "initializer_range": 0.02,
+  "layer_norm_epsilon": 1e-05,
+  "model_type": "gpt2",
+  "n_ctx": 1024,
+  "n_embd": 768,
+  "n_head": 12,
+  "n_layer": 12,
+  "n_positions": 1024,
+  "resid_pdrop": 0.1,
+  "summary_activation": null,
+  "summary_first_dropout": 0.1,
+  "summary_proj_to_labels": true,
+  "summary_type": "cls_index",
+  "summary_use_proj": true,
+  "task_specific_params": {
+    "text-generation": {
+      "do_sample": true,
+      "max_length": 50
+    }
+  },
+  "vocab_size": 50257
+}
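One detail worth relating to app.py: n_ctx / n_positions of 1024 is GPT-2's context limit, and the tokenizer in app.py truncates to max_length=512, which stays safely inside it. A small illustrative check (sketch only, not committed):

# Illustrative check: tokenized inputs stay within GPT-2's position limit.
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
enc = tokenizer("Some Instruction text. Some Response text.",
                padding="max_length", truncation=True, max_length=512)
assert len(enc["input_ids"]) <= config.n_positions  # 512 <= 1024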
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/generation_config.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "bos_token_id": 50256,
+  "eos_token_id": 50256,
+  "transformers_version": "4.26.0.dev0",
+  "_from_model_config": true
+}
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:248dfc3911869ec493c76e65bf2fcf7f615828b0254c12b473182f0f81d3a707
+size 548105171
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"model_max_length": 1024}
models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+streamlit
+transformers==4.41.2
+torch==2.3.1
+datasets==2.20.0
+huggingface_hub==0.23.2
+accelerate==0.32.1
viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/cache-22a773a22cb9ef7a.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b900e8f715eeef7213691a1042362250e295b57ef37e5c76ddb07b8ab79d8e08
+size 20424
viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/cache-bef1f90cc85606a0.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:135fff71f54f468e93884d1bf9e0d874eb71e7817c9fe5d7767072f9a27d74bd
+size 180784
viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/cache-ce2a18a54b30a39e.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bb3b6c99393e9313b34a63e943841617ef514ffd198e6e20eb5307a1ac2d90b
+size 177005624
viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/dataset_info.json
ADDED
@@ -0,0 +1 @@
+{"description": "", "citation": "", "homepage": "", "license": "", "features": {"Instruction": {"dtype": "string", "_type": "Value"}, "Response": {"dtype": "string", "_type": "Value"}}, "builder_name": "json", "dataset_name": "indian-law-dataset", "config_name": "default", "version": {"version_str": "0.0.0", "major": 0, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 12911240, "num_examples": 24607, "dataset_name": "indian-law-dataset"}}, "download_checksums": {"hf://datasets/viber1/indian-law-dataset@705c4e2c852380d1120f51121ac1ed020b4f743b/train.jsonl": {"num_bytes": 14408595, "checksum": null}}, "download_size": 14408595, "dataset_size": 12911240, "size_in_bytes": 27319835}
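The schema recorded here is exactly what app.py's tokenize_function relies on: each row is an Instruction/Response pair of plain strings, with 24,607 rows in the train split. A quick inspection sketch (not part of the commit; without a cache_dir it re-downloads the dataset):

# Sanity-check the dataset schema that tokenize_function consumes.
from datasets import load_dataset

dataset = load_dataset("viber1/indian-law-dataset")
print(dataset["train"].num_rows)   # 24607 per dataset_info.json
example = dataset["train"][0]
print(example["Instruction"])
print(example["Response"])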
viber1___indian-law-dataset/default/0.0.0/705c4e2c852380d1120f51121ac1ed020b4f743b/indian-law-dataset-train.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4760527523f1b93bfe0bd6c9521aa0873060b437702d095af44e0d3141cad759
+size 12919232