ZeroRVC / app /train.py
github-actions[bot]
Sync to HuggingFace Spaces
f80c5ec
raw
history blame
No virus
5.73 kB
import os
import tempfile
import gradio as gr
import torch
from zerorvc import RVCTrainer, pretrained_checkpoints, SynthesizerTrnMs768NSFsid
from zerorvc.trainer import TrainingCheckpoint
from datasets import load_from_disk
from huggingface_hub import snapshot_download
from .zero import zero
from .model import accelerator, device
from .constants import BATCH_SIZE, ROOT_EXP_DIR, TRAINING_EPOCHS
@zero(duration=240)
def train_model(exp_dir: str, progress=gr.Progress()):
    """Train the model for ``TRAINING_EPOCHS`` epochs and save the result.

    The dataset is loaded from ``<exp_dir>/dataset``. Training resumes from
    the trainer's latest checkpoint in ``<exp_dir>/checkpoints`` or, when
    there is none, from the pretrained checkpoints. After the last epoch the
    final checkpoint and its generator (``G``) are written back to the
    checkpoint directory.

    Args:
        exp_dir: Experiment directory holding the ``dataset`` folder and the
            ``checkpoints`` folder.
        progress: Gradio progress tracker used to show a per-epoch bar.

    Returns:
        A one-line summary with the epoch count and the latest loss values.

    Raises:
        gr.Error: If the dataset directory does not exist, or if the trainer
            finished without yielding any checkpoint.
    """
    dataset = os.path.join(exp_dir, "dataset")
    if not os.path.exists(dataset):
        raise gr.Error("Dataset not found. Please prepare the dataset first.")

    ds = load_from_disk(dataset)
    checkpoint_dir = os.path.join(exp_dir, "checkpoints")
    trainer = RVCTrainer(checkpoint_dir)

    resume_from = trainer.latest_checkpoint()
    if resume_from is None:
        resume_from = pretrained_checkpoints()
        gr.Info("Starting training from pretrained checkpoints.")
    else:
        gr.Info(f"Resuming training from {resume_from}")

    epoch_bar = progress.tqdm(
        trainer.train(
            dataset=ds["train"],
            resume_from=resume_from,
            batch_size=BATCH_SIZE,
            epochs=TRAINING_EPOCHS,
            accelerator=accelerator,
        ),
        total=TRAINING_EPOCHS,
        unit="epochs",
        desc="Training",
    )

    # Start unbound-safe: if the trainer yields nothing, we must not reach the
    # save step with an undefined name.
    latest: "TrainingCheckpoint | None" = None
    for ckpt in epoch_bar:
        info = f"Epoch: {ckpt.epoch} loss: (gen: {ckpt.loss_gen:.4f}, fm: {ckpt.loss_fm:.4f}, mel: {ckpt.loss_mel:.4f}, kl: {ckpt.loss_kl:.4f}, disc: {ckpt.loss_disc:.4f})"
        print(info)
        latest = ckpt

    if latest is None:
        raise gr.Error("Training did not produce any checkpoint.")

    # Persist both the resumable training checkpoint and the generator.
    latest.save(trainer.checkpoint_dir)
    latest.G.save_pretrained(trainer.checkpoint_dir)

    result = f"{TRAINING_EPOCHS} epochs trained. Latest loss: (gen: {latest.loss_gen:.4f}, fm: {latest.loss_fm:.4f}, mel: {latest.loss_mel:.4f}, kl: {latest.loss_kl:.4f}, disc: {latest.loss_disc:.4f})"

    # Drop the trainer and hand cached CUDA memory back before returning.
    del trainer
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return result
def upload_model(exp_dir: str, repo: str, hf_token: str):
    """Push the model stored in ``<exp_dir>/checkpoints`` to the Hub.

    Args:
        exp_dir: Experiment directory containing the ``checkpoints`` folder.
        repo: Target repository ID.
        hf_token: Token passed to ``push_to_hub``.

    Raises:
        gr.Error: If the checkpoints folder does not exist.
    """
    model_path = os.path.join(exp_dir, "checkpoints")
    if os.path.exists(model_path):
        gr.Info("Uploading model")
        # Load the saved synthesizer and push it with private=True.
        SynthesizerTrnMs768NSFsid.from_pretrained(model_path).push_to_hub(
            repo, token=hf_token, private=True
        )
        gr.Info("Model uploaded successfully")
    else:
        raise gr.Error("Model not found")
def upload_checkpoints(exp_dir: str, repo: str, hf_token: str):
    """Push the training checkpoints in ``<exp_dir>/checkpoints`` to the Hub.

    Args:
        exp_dir: Experiment directory containing the ``checkpoints`` folder.
        repo: Target repository ID.
        hf_token: Token passed to the trainer's ``push_to_hub``.

    Raises:
        gr.Error: If the checkpoints folder does not exist.
    """
    ckpt_path = os.path.join(exp_dir, "checkpoints")
    if os.path.exists(ckpt_path):
        gr.Info("Uploading checkpoints")
        # The trainer is built on the checkpoint folder and pushed with private=True.
        RVCTrainer(ckpt_path).push_to_hub(repo, token=hf_token, private=True)
        gr.Info("Checkpoints uploaded successfully")
    else:
        raise gr.Error("Checkpoints not found")
def fetch_model(exp_dir: str, repo: str, hf_token: str):
    """Download the model files of ``repo`` into ``<exp_dir>/checkpoints``.

    Only ``README.md``, ``config.json`` and ``model.safetensors`` are
    requested from the repository.

    Args:
        exp_dir: Experiment directory; when empty, a new temporary directory
            is created under ``ROOT_EXP_DIR`` and used instead.
        repo: Source repository ID.
        hf_token: Token passed to ``snapshot_download``.

    Returns:
        The experiment directory that was used (possibly newly created).
    """
    # Fall back to a fresh experiment directory when none was supplied.
    target_dir = exp_dir or tempfile.mkdtemp(dir=ROOT_EXP_DIR)
    destination = os.path.join(target_dir, "checkpoints")

    gr.Info("Fetching model")
    snapshot_download(
        repo,
        token=hf_token,
        local_dir=destination,
        allow_patterns=["README.md", "config.json", "model.safetensors"],
    )
    gr.Info("Model fetched successfully")

    return target_dir
def fetch_checkpoints(exp_dir: str, repo: str, hf_token: str):
    """Download every file of ``repo`` into ``<exp_dir>/checkpoints``.

    Args:
        exp_dir: Experiment directory; when empty, a new temporary directory
            is created under ``ROOT_EXP_DIR`` and used instead.
        repo: Source repository ID.
        hf_token: Token passed to ``snapshot_download``.

    Returns:
        The experiment directory that was used (possibly newly created).
    """
    # Fall back to a fresh experiment directory when none was supplied.
    target_dir = exp_dir or tempfile.mkdtemp(dir=ROOT_EXP_DIR)
    destination = os.path.join(target_dir, "checkpoints")

    gr.Info("Fetching checkpoints")
    snapshot_download(repo, token=hf_token, local_dir=destination)
    gr.Info("Checkpoints fetched successfully")

    return target_dir
class TrainTab:
    """Gradio tab for training a model and syncing it with Hugging Face.

    Call ``ui()`` inside a Gradio layout context to create the components,
    then ``build()`` to attach the click handlers.
    """

    def __init__(self):
        """Create an empty tab; components are created later by ``ui()``."""
        pass

    def ui(self):
        """Create the tab's components and store them on the instance."""
        gr.Markdown("# Training")
        gr.Markdown(
            "You can start training the model by clicking the button below. "
            f"Each time you click the button, the model will train for {TRAINING_EPOCHS} epochs, which takes about 3 minutes on ZeroGPU (A100). "
        )

        with gr.Row():
            self.train_btn = gr.Button(value="Train", variant="primary")
            self.result = gr.Textbox(label="Training Result", lines=3)

        gr.Markdown("## Sync Model and Checkpoints with Hugging Face")
        gr.Markdown(
            "You can upload the trained model and checkpoints to Hugging Face for sharing or further training."
        )

        self.repo = gr.Textbox(label="Repository ID", placeholder="username/repo")

        with gr.Row():
            self.upload_model_btn = gr.Button(value="Upload Model", variant="primary")
            self.upload_checkpoints_btn = gr.Button(
                value="Upload Checkpoints", variant="primary"
            )

        with gr.Row():
            self.fetch_model_btn = gr.Button(value="Fetch Model", variant="primary")
            # Backward-compatible alias: the attribute was originally
            # misspelled "fetch_mode_btn"; keep it for existing users.
            self.fetch_mode_btn = self.fetch_model_btn
            self.fetch_checkpoints_btn = gr.Button(
                value="Fetch Checkpoints", variant="primary"
            )

    def build(self, exp_dir: gr.Textbox, hf_token: gr.Textbox):
        """Attach click handlers to the buttons created by ``ui()``.

        Args:
            exp_dir: Textbox holding the experiment directory; the fetch
                handlers also write their returned directory back to it.
            hf_token: Textbox holding the Hugging Face token.
        """
        self.train_btn.click(
            fn=train_model,
            inputs=[exp_dir],
            outputs=[self.result],
        )

        self.upload_model_btn.click(
            fn=upload_model,
            inputs=[exp_dir, self.repo, hf_token],
        )

        self.upload_checkpoints_btn.click(
            fn=upload_checkpoints,
            inputs=[exp_dir, self.repo, hf_token],
        )

        self.fetch_model_btn.click(
            fn=fetch_model,
            inputs=[exp_dir, self.repo, hf_token],
            outputs=[exp_dir],
        )

        self.fetch_checkpoints_btn.click(
            fn=fetch_checkpoints,
            inputs=[exp_dir, self.repo, hf_token],
            outputs=[exp_dir],
        )