"""Gradio "Training" tab: train an RVC model and sync it with Hugging Face."""

import os
import tempfile
from typing import Optional

import gradio as gr
import torch
from zerorvc import RVCTrainer, pretrained_checkpoints, SynthesizerTrnMs768NSFsid
from zerorvc.trainer import TrainingCheckpoint
from datasets import load_from_disk
from huggingface_hub import snapshot_download
from .zero import zero
from .model import accelerator, device
from .constants import BATCH_SIZE, ROOT_EXP_DIR, TRAINING_EPOCHS


def _checkpoint_dir(exp_dir: str) -> str:
    """Return the path of the ``checkpoints`` directory inside *exp_dir*."""
    return os.path.join(exp_dir, "checkpoints")


def _ensure_exp_dir(exp_dir: str) -> str:
    """Return *exp_dir*, or a fresh directory under ``ROOT_EXP_DIR`` if it is empty."""
    if exp_dir:
        return exp_dir
    return tempfile.mkdtemp(dir=ROOT_EXP_DIR)


def _format_losses(ckpt: TrainingCheckpoint) -> str:
    """Render the five loss values of *ckpt* as a parenthesised summary."""
    return (
        f"(gen: {ckpt.loss_gen:.4f}, fm: {ckpt.loss_fm:.4f}, "
        f"mel: {ckpt.loss_mel:.4f}, kl: {ckpt.loss_kl:.4f}, "
        f"disc: {ckpt.loss_disc:.4f})"
    )


@zero(duration=240)
def train_model(exp_dir: str, progress=gr.Progress()):
    """Train for ``TRAINING_EPOCHS`` epochs and save the resulting checkpoint.

    Training resumes from the latest checkpoint found in
    ``<exp_dir>/checkpoints``; when there is none, it starts from the
    pretrained checkpoints instead.

    Args:
        exp_dir: Experiment directory. Must contain a ``dataset`` directory
            loadable with ``datasets.load_from_disk``.
        progress: Gradio progress tracker used to display per-epoch progress.

    Returns:
        A short summary containing the losses of the last epoch.

    Raises:
        gr.Error: If the dataset directory does not exist, or if training
            finished without yielding any checkpoint.
    """
    dataset = os.path.join(exp_dir, "dataset")
    if not os.path.exists(dataset):
        raise gr.Error("Dataset not found. Please prepare the dataset first.")

    ds = load_from_disk(dataset)
    trainer = RVCTrainer(_checkpoint_dir(exp_dir))

    try:
        resume_from = trainer.latest_checkpoint()
        if resume_from is None:
            resume_from = pretrained_checkpoints()
            gr.Info("Starting training from pretrained checkpoints.")
        else:
            gr.Info(f"Resuming training from {resume_from}")

        epochs = progress.tqdm(
            trainer.train(
                dataset=ds["train"],
                resume_from=resume_from,
                batch_size=BATCH_SIZE,
                epochs=TRAINING_EPOCHS,
                accelerator=accelerator,
            ),
            total=TRAINING_EPOCHS,
            unit="epochs",
            desc="Training",
        )

        latest: Optional[TrainingCheckpoint] = None
        for ckpt in epochs:
            print(f"Epoch: {ckpt.epoch} loss: {_format_losses(ckpt)}")
            latest = ckpt

        # Without this guard an empty training run would fail below with an
        # UnboundLocalError instead of a message the UI can display.
        if latest is None:
            raise gr.Error("Training finished without producing a checkpoint.")

        # Only the final checkpoint is persisted, once, after the last epoch.
        latest.save(trainer.checkpoint_dir)
        latest.G.save_pretrained(trainer.checkpoint_dir)

        return (
            f"{TRAINING_EPOCHS} epochs trained. \n"
            f"Latest loss: {_format_losses(latest)}"
        )
    finally:
        # Drop the trainer and release cached GPU memory even when training
        # fails part-way, so a failed run does not keep memory reserved.
        del trainer
        if device.type == "cuda":
            torch.cuda.empty_cache()


def upload_model(exp_dir: str, repo: str, hf_token: str):
    """Push the trained model in ``<exp_dir>/checkpoints`` to a private Hub repo.

    Args:
        exp_dir: Experiment directory containing the ``checkpoints`` directory.
        repo: Target repository ID on Hugging Face.
        hf_token: Access token passed to ``push_to_hub``.

    Raises:
        gr.Error: If the checkpoints directory does not exist.
    """
    checkpoint_dir = _checkpoint_dir(exp_dir)
    if not os.path.exists(checkpoint_dir):
        raise gr.Error("Model not found")

    gr.Info("Uploading model")
    model = SynthesizerTrnMs768NSFsid.from_pretrained(checkpoint_dir)
    model.push_to_hub(repo, token=hf_token, private=True)
    gr.Info("Model uploaded successfully")


def upload_checkpoints(exp_dir: str, repo: str, hf_token: str):
    """Push the training checkpoints in ``<exp_dir>/checkpoints`` to a private Hub repo.

    Args:
        exp_dir: Experiment directory containing the ``checkpoints`` directory.
        repo: Target repository ID on Hugging Face.
        hf_token: Access token passed to ``push_to_hub``.

    Raises:
        gr.Error: If the checkpoints directory does not exist.
    """
    checkpoint_dir = _checkpoint_dir(exp_dir)
    if not os.path.exists(checkpoint_dir):
        raise gr.Error("Checkpoints not found")

    gr.Info("Uploading checkpoints")
    trainer = RVCTrainer(checkpoint_dir)
    trainer.push_to_hub(repo, token=hf_token, private=True)
    gr.Info("Checkpoints uploaded successfully")


def fetch_model(exp_dir: str, repo: str, hf_token: str):
    """Download only the model files of *repo* into ``<exp_dir>/checkpoints``.

    Args:
        exp_dir: Experiment directory; when empty, a new temporary directory
            is created under ``ROOT_EXP_DIR``.
        repo: Source repository ID on Hugging Face.
        hf_token: Access token passed to ``snapshot_download``.

    Returns:
        The experiment directory that was used (possibly newly created).
    """
    exp_dir = _ensure_exp_dir(exp_dir)
    checkpoint_dir = _checkpoint_dir(exp_dir)

    gr.Info("Fetching model")
    files = ["README.md", "config.json", "model.safetensors"]
    snapshot_download(
        repo, token=hf_token, local_dir=checkpoint_dir, allow_patterns=files
    )
    gr.Info("Model fetched successfully")

    return exp_dir


def fetch_checkpoints(exp_dir: str, repo: str, hf_token: str):
    """Download the full contents of *repo* into ``<exp_dir>/checkpoints``.

    Args:
        exp_dir: Experiment directory; when empty, a new temporary directory
            is created under ``ROOT_EXP_DIR``.
        repo: Source repository ID on Hugging Face.
        hf_token: Access token passed to ``snapshot_download``.

    Returns:
        The experiment directory that was used (possibly newly created).
    """
    exp_dir = _ensure_exp_dir(exp_dir)
    checkpoint_dir = _checkpoint_dir(exp_dir)

    gr.Info("Fetching checkpoints")
    snapshot_download(repo, token=hf_token, local_dir=checkpoint_dir)
    gr.Info("Checkpoints fetched successfully")

    return exp_dir


class TrainTab:
    """Training tab: a train button plus Hugging Face upload/fetch controls."""

    def __init__(self):
        pass

    def ui(self):
        """Create the tab's components; call inside an active Gradio layout."""
        gr.Markdown("# Training")
        gr.Markdown(
            "You can start training the model by clicking the button below. "
            f"Each time you click the button, the model will train for {TRAINING_EPOCHS} epochs, which takes about 3 minutes on ZeroGPU (A100). \n"
        )

        with gr.Row():
            self.train_btn = gr.Button(value="Train", variant="primary")
            self.result = gr.Textbox(label="Training Result", lines=3)

        gr.Markdown("## Sync Model and Checkpoints with Hugging Face")
        gr.Markdown(
            "You can upload the trained model and checkpoints to Hugging Face for sharing or further training."
        )

        self.repo = gr.Textbox(label="Repository ID", placeholder="username/repo")

        with gr.Row():
            self.upload_model_btn = gr.Button(value="Upload Model", variant="primary")
            self.upload_checkpoints_btn = gr.Button(
                value="Upload Checkpoints", variant="primary"
            )

        with gr.Row():
            # NOTE: attribute name kept as-is ("fetch_mode_btn") so any code
            # that already references it keeps working.
            self.fetch_mode_btn = gr.Button(value="Fetch Model", variant="primary")
            self.fetch_checkpoints_btn = gr.Button(
                value="Fetch Checkpoints", variant="primary"
            )

    def build(self, exp_dir: gr.Textbox, hf_token: gr.Textbox):
        """Wire the buttons created by :meth:`ui` to their handlers.

        Args:
            exp_dir: Textbox holding the experiment directory. The fetch
                handlers also write the directory they used back into it.
            hf_token: Textbox holding the Hugging Face access token.
        """
        self.train_btn.click(
            fn=train_model,
            inputs=[exp_dir],
            outputs=[self.result],
        )

        self.upload_model_btn.click(
            fn=upload_model,
            inputs=[exp_dir, self.repo, hf_token],
        )

        self.upload_checkpoints_btn.click(
            fn=upload_checkpoints,
            inputs=[exp_dir, self.repo, hf_token],
        )

        self.fetch_mode_btn.click(
            fn=fetch_model,
            inputs=[exp_dir, self.repo, hf_token],
            outputs=[exp_dir],
        )

        self.fetch_checkpoints_btn.click(
            fn=fetch_checkpoints,
            inputs=[exp_dir, self.repo, hf_token],
            outputs=[exp_dir],
        )