"""Train a CFM-based TTS model (F5TTS_Base with DiT, or E2TTS_Base with UNetT).

All run configuration lives in the module-level constants below; edit them
and run this file as a script.
"""

from model import CFM, UNetT, DiT, MMDiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset


# -------------------------- Dataset Settings --------------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256

tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"


# -------------------------- Training Settings -------------------------- #

exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base

learning_rate = 7.5e-5

batch_size_per_gpu = 38400  # 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame"  # "frame" or "sample"
max_samples = 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.0

epochs = 11  # use linear decay, thus epochs control the slope
num_warmup_updates = 20000  # warmup steps
save_per_updates = 50000  # save checkpoint per steps
last_per_steps = 5000  # save last checkpoint per steps

# model params
if exp_name == "F5TTS_Base":
    wandb_resume_id = None
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    wandb_resume_id = None
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
else:
    # Fail fast here: otherwise model_cls / model_cfg / wandb_resume_id stay
    # unbound and the run dies later inside main() with an opaque NameError.
    raise ValueError(
        f"Unknown exp_name {exp_name!r}; expected 'F5TTS_Base' or 'E2TTS_Base'"
    )

# ----------------------------------------------------------------------- #


def main():
    """Build the tokenizer, CFM model and Trainer, then train on the dataset.

    Uses the module-level settings above. The transformer backbone and its
    config are selected by ``exp_name``.
    """
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Shared by the model (CFM) and the dataset loader.
    mel_spec_kwargs = dict(
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    )

    model = CFM(
        transformer=model_cls(
            **model_cfg,
            text_num_embeds=vocab_size,
            mel_dim=n_mel_channels,
        ),
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        epochs,
        learning_rate,
        num_warmup_updates=num_warmup_updates,
        save_per_updates=save_per_updates,
        checkpoint_path=f'ckpts/{exp_name}',
        batch_size=batch_size_per_gpu,
        batch_size_type=batch_size_type,
        max_samples=max_samples,
        grad_accumulation_steps=grad_accumulation_steps,
        max_grad_norm=max_grad_norm,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=last_per_steps,
    )

    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
    trainer.train(
        train_dataset,
        resumable_with_seed=666,  # seed for shuffling dataset
    )


if __name__ == '__main__':
    main()