vampnet / scripts /exp /fine_tune.py
pseeth's picture
new unlooper ui (#9)
d9bad6c
raw
history blame
2.38 kB
import argbind
from pathlib import Path
import yaml
from typing import List
"""example output: (yaml)
"""
@argbind.bind(without_prefix=True, positional=True)
def fine_tune(audio_files_or_folders: List[str], name: str):
conf_dir = Path("conf")
assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
conf_dir = conf_dir / "generated"
conf_dir.mkdir(exist_ok=True)
finetune_dir = conf_dir / name
finetune_dir.mkdir(exist_ok=True)
finetune_c2f_conf = {
"$include": ["conf/lora/lora.yml"],
"fine_tune": True,
"train/AudioLoader.sources": audio_files_or_folders,
"val/AudioLoader.sources": audio_files_or_folders,
"VampNet.n_codebooks": 14,
"VampNet.n_conditioning_codebooks": 4,
"VampNet.embedding_dim": 1280,
"VampNet.n_layers": 16,
"VampNet.n_heads": 20,
"AudioDataset.duration": 3.0,
"AudioDataset.loudness_cutoff": -40.0,
"save_path": f"./runs/{name}/c2f",
"fine_tune_checkpoint": "./models/vampnet/c2f.pth"
}
finetune_coarse_conf = {
"$include": ["conf/lora/lora.yml"],
"fine_tune": True,
"train/AudioLoader.sources": audio_files_or_folders,
"val/AudioLoader.sources": audio_files_or_folders,
"save_path": f"./runs/{name}/coarse",
"fine_tune_checkpoint": "./models/vampnet/coarse.pth"
}
interface_conf = {
"Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
"Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
"AudioLoader.sources": [audio_files_or_folders],
}
# save the confs
with open(finetune_dir / "c2f.yml", "w") as f:
yaml.dump(finetune_c2f_conf, f)
with open(finetune_dir / "coarse.yml", "w") as f:
yaml.dump(finetune_coarse_conf, f)
with open(finetune_dir / "interface.yml", "w") as f:
yaml.dump(interface_conf, f)
print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
if __name__ == "__main__":
args = argbind.parse_args()
with argbind.scope(args):
fine_tune()