|
|
|
|
|
|
|
|
|
# RWKV training launcher. Everything below runs only when executed as a script.
if __name__ == "__main__":

# Parse-time imports only; the heavy ML imports are deferred until after argparse.
from argparse import ArgumentParser

from pytorch_lightning import Trainer

# NOTE(review): upstream marks this script as unfinished.
print("########## work in progress ##########")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Command-line interface. The defaults below are the training hyperparameters;
# pytorch_lightning's own Trainer flags (devices, num_nodes, precision,
# strategy, accelerator, ...) are appended by Trainer.add_argparse_args below.
parser = ArgumentParser()

# -- run management --------------------------------------------------------
parser.add_argument("--load_model", default="", type=str)  # checkpoint to load ("" = fresh init)
parser.add_argument("--wandb", default="", type=str)  # presumably a wandb project name; not used in this file -- verify in src.trainer
parser.add_argument("--proj_dir", default="out", type=str)  # output/checkpoint directory
# NOTE: argparse applies type=int to a string default, so the unset value is -1.
parser.add_argument("--random_seed", default="-1", type=int)

# -- data ------------------------------------------------------------------
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)  # one of the types asserted later
parser.add_argument("--vocab_size", default=0, type=int)  # overwritten from the dataset later
parser.add_argument("--vocab_size_delta", default=0, type=int)  # extra vocab slots added after loading

# -- schedule / checkpoint cadence ----------------------------------------
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int)  # steps per "epoch" (a bookkeeping unit here)
parser.add_argument("--epoch_count", default=500, type=int)
parser.add_argument("--epoch_begin", default=0, type=int)
parser.add_argument("--epoch_save", default=5, type=int)  # save a checkpoint every N "epochs"

# -- model shape -----------------------------------------------------------
parser.add_argument("--micro_bsz", default=12, type=int)  # per-device batch size
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--pre_ffn", default=0, type=int)
parser.add_argument("--head_qk", default=0, type=int)
parser.add_argument("--tiny_att_dim", default=0, type=int)
parser.add_argument("--tiny_att_layer", default=-999, type=int)

# -- optimizer -------------------------------------------------------------
parser.add_argument("--lr_init", default=6e-4, type=float)
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int)
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float)
parser.add_argument("--adam_eps", default=1e-8, type=float)

# -- training features -----------------------------------------------------
parser.add_argument("--grad_cp", default=0, type=int)  # 1 = gradient checkpointing (see banner text)
parser.add_argument("--my_pile_stage", default=0, type=int)  # >0 enables the Pile preset logic below
parser.add_argument("--my_pile_shift", default=-1, type=int)  # <0 = pick a default from ctx_len
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int)
parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed ZeRO bucket size in MB

# -- image-model options (used when data_type == "wds_img") ----------------
# NOTE(review): default is an int while type=str; argparse leaves non-string
# defaults unconverted, so the unset value is the int 0 (only formatted into
# run_name here, which tolerates it).
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)

# -- misc / experimental ---------------------------------------------------
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)  # 1 = tolerate missing keys in the checkpoint
parser.add_argument("--magic_prime", default=0, type=int)  # 0 = use the Pile preset value
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default=0, type=int)

# Fold in pytorch_lightning Trainer's flags, then parse everything at once.
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
|
|
|
|
|
|
|
# Heavy imports are deferred until after argument parsing (keeps --help fast
# and avoids importing torch/deepspeed for a bad command line).
import os, warnings, math, datetime, sys, time

import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
|
|
# Optional global seeding. As the warning says, a fixed global seed affects
# sampling across ranks in multi-GPU runs, hence the triple-printed banner.
if args.random_seed >= 0:
    print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
    seed_everything(args.random_seed)

np.set_printoptions(precision=4, suppress=True, linewidth=200)
# Silence two noisy pytorch_lightning warnings that are expected in this setup.
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
|
|
|
|
# Post-parse massaging of args: stamp the run and hard-wire Trainer behavior.
# PL's built-in checkpointing/logging/validation are disabled -- presumably
# handled by the custom train_callback instead (see Trainer construction below).
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)  # effectively: never validate
args.log_every_n_steps = int(1e20)  # effectively: never log through PL
args.max_epochs = -1  # no epoch limit; run until externally stopped
args.betas = (args.beta1, args.beta2)
# Global batch size across all nodes and devices.
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
# Presumably read at import time by the CUDA kernel in src.model -- TODO confirm.
os.environ["RWKV_T_MAX"] = str(args.ctx_len)

# Human-readable run name; image runs also get a per-run project directory.
if args.data_type == "wds_img":
    args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
    args.proj_dir = f"{args.proj_dir}-{args.run_name}"
else:
    args.run_name = f"{args.vocab_size}+{args.vocab_size_delta} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
# Create the output directory up front. exist_ok=True replaces the previous
# exists()-then-makedirs() pattern, which had a TOCTOU race: multiple DDP
# ranks launching at once could both see "missing" and one would crash.
os.makedirs(args.proj_dir, exist_ok=True)
|
|
|
# Pile-training presets and auto-resume logic (active only for my_pile_stage > 0).
if args.my_pile_stage > 0:
    # Presets keyed by ctx_len: (magic_prime, epoch_count, default my_pile_shift).
    user_magic_prime = args.magic_prime
    pile_presets = {
        1024: (324331313, 8043, 0),
        2048: (162165671, 4021, 512),
        4096: (81082817, 2010, 768),
    }
    preset = pile_presets.get(args.ctx_len)
    if preset is not None:
        args.magic_prime, args.epoch_count, shift_default = preset
        if args.my_pile_shift < 0:
            args.my_pile_shift = shift_default
    # An explicit --magic_prime on the command line wins over the preset.
    if user_magic_prime > 0:
        args.magic_prime = user_magic_prime

    # Fixed 40320 samples per "epoch"; the global batch size must divide it.
    args.epoch_steps = 40320 // args.real_bsz
    assert args.epoch_steps * args.real_bsz == 40320

    if args.my_pile_stage == 2:
        assert args.lr_final == args.lr_init

    if args.my_pile_stage >= 2:
        # Resume from the newest rwkv-*.pth in proj_dir; "init" sorts as -1.
        # NOTE(review): assumes at least one matching checkpoint exists.
        epochs_found = []
        for fname in os.listdir(args.proj_dir):
            if not (fname.startswith("rwkv") and fname.endswith(".pth")):
                continue
            tag = fname.split("-")[1].split(".")[0]
            epochs_found.append(-1 if tag == "init" else int(tag))
        epochs_found.sort()
        max_p = epochs_found[-1]
        if len(epochs_found) > 1:
            # Remember the runner-up as a fallback for a corrupt checkpoint.
            args.my_pile_prev_p = epochs_found[-2]
        name_tag = "init" if max_p == -1 else str(max_p)
        args.load_model = f"{args.proj_dir}/rwkv-{name_tag}.pth"
        args.warmup_steps = 10 if args.my_pile_stage == 2 else 30
        args.epoch_begin = max_p + 1
|
|
|
# Derived sizes for the banner.
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
# Startup banner, printed on rank 0 only.
rank_zero_info(
    f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
# Dump the full (massaged) configuration for reproducibility.
rank_zero_info(str(vars(args)) + "\n")
|
|
|
# Validate data/precision flags and configure the numeric backends.
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

if args.lr_final == 0 or args.lr_init == 0:
    rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")

assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
# Presumably read at import time by src.model -- TODO confirm.
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
    rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
elif args.precision == "fp16":
    rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

# JIT is disabled under ZeRO stage 3 -- presumably the two don't mix; verify
# against src.model before changing.
os.environ["RWKV_JIT_ON"] = "0" if "deepspeed_stage_3" in args.strategy else "1"

# cuDNN autotune on; TF32 matmuls allowed except when running true fp32.
allow_tf32 = args.precision != "fp32"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = allow_tf32
torch.backends.cuda.matmul.allow_tf32 = allow_tf32

# Translate the string flag into what pytorch_lightning's Trainer expects:
# "fp32"/"tf32" -> 32, "fp16" -> 16, "bf16" -> "bf16".
if "32" in args.precision:
    args.precision = 32
elif args.precision == "fp16":
    args.precision = 16
else:
    args.precision = "bf16"
|
|
|
|
|
|
|
# These project imports are deferred until after RWKV_T_MAX / RWKV_FLOAT_MODE /
# RWKV_JIT_ON were set in os.environ above -- src.model presumably reads them
# at import time (TODO confirm), so the order here is load-bearing.
from src.trainer import train_callback, generate_init_weight
from src.dataset import MyDataset

train_data = MyDataset(args)
# The dataset decides the real vocab size (the CLI value was only a hint).
args.vocab_size = train_data.vocab_size

# Pick the model class by data type: image model vs the language model.
if args.data_type == 'wds_img':
    from src.model_img import RWKV_IMG
    model = RWKV_IMG(args)
else:
    from src.model import RWKV
    model = RWKV(args)

# With no checkpoint given (or during pile stage 1), write freshly initialized
# weights to disk and then load them back through the normal path below.
if len(args.load_model) == 0 or args.my_pile_stage == 1:
    init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
    generate_init_weight(model, init_weight_name)
    args.load_model = init_weight_name
|
|
|
print(f"########## Loading {args.load_model}... ##########") |
|
try: |
|
load_dict = torch.load(args.load_model, map_location="cpu") |
|
except: |
|
print(f"Bad checkpoint {args.load_model}") |
|
if args.my_pile_stage >= 2: |
|
max_p = args.my_pile_prev_p |
|
if max_p == -1: |
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth" |
|
else: |
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" |
|
args.epoch_begin = max_p + 1 |
|
print(f"Trying {args.load_model}") |
|
load_dict = torch.load(args.load_model, map_location="cpu") |
|
|
|
# Apply the checkpoint. With --load_partial 1, any key the checkpoint lacks
# is filled from the freshly initialized model instead of failing.
if args.load_partial == 1:
    for name, tensor in model.state_dict().items():
        load_dict.setdefault(name, tensor)
model.load_state_dict(load_dict)

# Optionally grow the embedding to an extended vocabulary after loading.
if args.vocab_size_delta > 0:
    expanded_vocab = args.vocab_size + args.vocab_size_delta
    model.resize_emb(expanded_vocab)
    args.vocab_size = expanded_vocab
|
|
|
# Build the Lightning trainer from the massaged args plus the custom callback.
trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)])

# For deepspeed strategies, size the ZeRO communication buckets from the CLI.
if "deepspeed" in args.strategy:
    bucket_bytes = args.ds_bucket_mb * 1000 * 1000
    zero_config = trainer.strategy.config["zero_optimization"]
    zero_config["allgather_bucket_size"] = bucket_bytes
    zero_config["reduce_bucket_size"] = bucket_bytes

# shuffle=False: sample ordering is presumably handled inside MyDataset
# itself -- verify against src.dataset before changing.
data_loader = DataLoader(
    train_data,
    shuffle=False,
    pin_memory=True,
    batch_size=args.micro_bsz,
    num_workers=1,
    persistent_workers=False,
    drop_last=True,
)

trainer.fit(model, data_loader)
|
|