Better performance on short context lengths
Nice to see the work you're doing with my dataset!
I recommend chunking the chapters to fit within a reasonable context size (2048-4096). Using RoPE scaling gave poor results in my testing, and it also took more VRAM.
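(For reference, by RoPE scaling I mean loading the model with a stretched context window, roughly along these lines; the scaling type and factor below are just placeholders, not my exact settings. The chunking script itself follows.)

```python
from transformers import AutoModelForCausalLM

# Rough sketch of context extension via RoPE scaling (placeholder settings).
# A linear factor of 2.0 stretches a 2048-token base context toward 4096,
# at the cost of extra VRAM for the longer sequences.
model = AutoModelForCausalLM.from_pretrained(
    "NilanE/tinyllama-relora-merge",
    rope_scaling={"type": "linear", "factor": 2.0},
)
```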
```python
from transformers import AutoTokenizer
import jsonlines
import os

tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")

max_seq_len = 2048  # max context length

# SFT prompt used at train time; its tokens count against each chunk's budget
prompt = "Translate this from Japanese to English:\n### JAPANESE:\n\n### ENGLISH:\n</s>"

input_file_path = "dataset.jsonl"
output_file_path = input_file_path.split('.')[0] + "-chunked." + input_file_path.split('.')[1]

promptTokens = len(tokenizer.tokenize(prompt))

# tolerance
max_seq_len -= 10

longLines = 0
skippedDocs = 0

if os.path.exists(output_file_path):
    os.remove(output_file_path)

with jsonlines.open(input_file_path) as reader, jsonlines.open(output_file_path, 'a') as writer:
    for entry in reader:
        src_lines = entry['src'].strip().split('\n')
        trg_lines = entry['trg'].strip().split('\n')

        out_src = []
        out_trg = []

        tokenCount = 0
        lastTokenCount = 0

        try:
            for x in range(len(src_lines)):
                # tentatively add the next source/target line pair to the current chunk
                out_src.append(src_lines[x])
                out_trg.append(trg_lines[x])
                out_src_string = "\n".join(out_src)
                out_trg_string = "\n".join(out_trg)
                tokenCount = len(tokenizer.tokenize(out_src_string.strip() + out_trg_string.strip())) + promptTokens

                if tokenCount - lastTokenCount < max_seq_len - 1:  # the new line pair itself fits
                    if tokenCount > max_seq_len - 1:
                        # chunk is full: write everything except the pair just added,
                        # then start the next chunk with it
                        src_end = out_src.pop()
                        trg_end = out_trg.pop()
                        out_src_string = "\n".join(out_src)
                        out_trg_string = "\n".join(out_trg)
                        data = {
                            'src': out_src_string.strip(),
                            'trg': out_trg_string.strip()
                        }
                        writer.write(data)
                        out_src = [src_end]
                        out_trg = [trg_end]
                    elif x + 1 == len(src_lines):
                        # last line of the document: flush whatever remains
                        data = {
                            'src': out_src_string.strip(),
                            'trg': out_trg_string.strip()
                        }
                        writer.write(data)
                else:
                    # this single line pair is longer than max_seq_len: drop it
                    out_src.pop()
                    out_trg.pop()
                    out_src_string = "\n".join(out_src)
                    out_trg_string = "\n".join(out_trg)
                    tokenCount = len(tokenizer.tokenize(out_src_string.strip() + out_trg_string.strip())) + promptTokens
                    longLines += 1

                lastTokenCount = tokenCount
        except Exception:
            # e.g. mismatched src/trg line counts
            skippedDocs += 1

print(f"LINES LONGER THAN MAX SEQUENCE LENGTH: {longLines}")
print(f"SKIPPED DOCS: {skippedDocs}")
```
Here's the script I use for chunking my dataset. I actually tested a 7B model (augmxnt/shisa-gamma-7b-v1) on a subset of the dataset a while ago and had amazing results, even without much training.
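If you want to sanity-check the chunked output, a quick sketch like this (re-using the same tokenizer and prompt as above, and assuming the output file name produced by the script) will flag any chunk that still goes over the budget:

```python
from transformers import AutoTokenizer
import jsonlines

# Sanity check (sketch): re-tokenize each chunk together with the SFT prompt
# and report anything that still exceeds the intended context length.
tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")
prompt = "Translate this from Japanese to English:\n### JAPANESE:\n\n### ENGLISH:\n</s>"
max_seq_len = 2048

with jsonlines.open("dataset-chunked.jsonl") as reader:
    for i, entry in enumerate(reader):
        n_tokens = len(tokenizer.tokenize(prompt + entry['src'] + entry['trg']))
        if n_tokens > max_seq_len:
            print(f"chunk {i} is over budget: {n_tokens} tokens")
```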
I also filtered out the partial chapter titles that were in some entries, and I'm currently pushing the fixed dataset to the Hub!
Yeah, I'll try that script with the newer version of your dataset. I did train it with 4k context first as well, and that did seem to work better than training with a longer context. I might try other Japanese models, but I chose this base model since it had pretty good benchmark scores.
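For reference, this is roughly how I'd turn the chunked file into training text at 4k context (just a sketch: the filled-in prompt format is my reading of your template, and the actual trainer setup is omitted):

```python
from datasets import load_dataset

# Sketch: load the chunked JSONL and build the SFT text field.
dataset = load_dataset("json", data_files="dataset-chunked.jsonl", split="train")

def format_example(example):
    # Mirrors the prompt template from the chunking script above (my assumption
    # of how src/trg slot into it).
    text = (
        "Translate this from Japanese to English:\n"
        f"### JAPANESE:\n{example['src']}\n"
        f"### ENGLISH:\n{example['trg']}</s>"
    )
    return {"text": text}

dataset = dataset.map(format_example)
# The "text" field is then tokenized with max_length=4096 (truncation on) for training.
```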