Hugo Flores Garcia committed
Commit • 75a7169
1 Parent(s): 13b04cf
efficient lora ckpts
Browse files
- README.md +24 -38
- conf/{generated → generated-v0}/berta-goldman-speech/c2f.yml +0 -0
- conf/{generated → generated-v0}/berta-goldman-speech/coarse.yml +0 -0
- conf/{generated → generated-v0}/berta-goldman-speech/interface.yml +0 -0
- conf/{generated → generated-v0}/gamelan-xeno-canto/c2f.yml +0 -0
- conf/{generated → generated-v0}/gamelan-xeno-canto/coarse.yml +0 -0
- conf/{generated → generated-v0}/gamelan-xeno-canto/interface.yml +0 -0
- conf/{generated → generated-v0}/nasralla/c2f.yml +0 -0
- conf/{generated → generated-v0}/nasralla/coarse.yml +0 -0
- conf/{generated → generated-v0}/nasralla/interface.yml +0 -0
- conf/generated/musica-bolero-marimba/c2f.yml +18 -0
- conf/generated/musica-bolero-marimba/coarse.yml +11 -0
- conf/generated/musica-bolero-marimba/interface.yml +8 -0
- conf/generated/xeno-canto/c2f.yml +15 -0
- conf/generated/xeno-canto/coarse.yml +8 -0
- conf/generated/xeno-canto/interface.yml +7 -0
- conf/lora/lora.yml +2 -2
- conf/vampnet.yml +1 -1
- demo.py +7 -3
- scripts/exp/fine_tune.py +16 -19
- scripts/exp/train.py +19 -2
- vampnet/interface.py +55 -9
README.md
CHANGED
@@ -33,41 +33,6 @@ Config files are stored in the `conf/` folder.
 
 Download the pretrained models from [this link](https://drive.google.com/file/d/1ZIBMJMt8QRE8MYYGjg4lH7v7BLbZneq2/view?usp=sharing). Then, extract the models to the `models/` folder.
 
-# How the code is structured
-
-This code was written fast to meet a publication deadline, so it can be messy and redundant at times. Currently working on cleaning it up.
-
-```
-├── conf            <- (conf files for training, finetuning, etc)
-├── demo.py         <- (gradio UI for playing with vampnet)
-├── env             <- (environment variables)
-│   └── env.sh
-├── models          <- (extract pretrained models)
-│   ├── spotdl
-│   │   ├── c2f.pth     <- (coarse2fine checkpoint)
-│   │   ├── coarse.pth  <- (coarse checkpoint)
-│   │   └── codec.pth   <- (codec checkpoint)
-│   └── wavebeat.pth
-├── README.md
-├── scripts
-│   ├── exp
-│   │   ├── eval.py     <- (eval script)
-│   │   └── train.py    <- (training/finetuning script)
-│   └── utils
-├── vampnet
-│   ├── beats.py        <- (beat tracking logic)
-│   ├── __init__.py
-│   ├── interface.py    <- (high-level programmatic interface)
-│   ├── mask.py
-│   ├── modules
-│   │   ├── activations.py
-│   │   ├── __init__.py
-│   │   ├── layers.py
-│   │   └── transformer.py  <- (architecture + sampling code)
-│   ├── scheduler.py
-│   └── util.py
-```
-
 # Usage
 
 First, you'll want to set up your environment
@@ -90,12 +55,33 @@ python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/ch
 ```
 
 ## Fine-tuning
-To fine-tune a model,
-
+To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`.
+The first two are used to fine-tune the coarse and fine models, respectively. The last one is used to launch the interface.
+
 ```bash
-python scripts/exp/
+python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" <fine_tune_name>
 ```
 
+This will create a folder under `conf/<fine_tune_name>/` with the 3 configuration files.
+
+The save paths will be set to `runs/<fine_tune_name>/coarse` and `runs/<fine_tune_name>/c2f`.
+
+Launch the coarse job:
+```bash
+python scripts/exp/train.py --args.load conf/<fine_tune_name>/coarse.yml
+```
+
+This will save the coarse model to `runs/<fine_tune_name>/coarse/ckpt/best/`.
+
+Launch the c2f job:
+```bash
+python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
+```
+
+Launch the interface:
+```bash
+python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
+```
 
 
 ## Launching the Gradio Interface
conf/{generated → generated-v0}/berta-goldman-speech/c2f.yml
RENAMED
File without changes

conf/{generated → generated-v0}/berta-goldman-speech/coarse.yml
RENAMED
File without changes

conf/{generated → generated-v0}/berta-goldman-speech/interface.yml
RENAMED
File without changes

conf/{generated → generated-v0}/gamelan-xeno-canto/c2f.yml
RENAMED
File without changes

conf/{generated → generated-v0}/gamelan-xeno-canto/coarse.yml
RENAMED
File without changes

conf/{generated → generated-v0}/gamelan-xeno-canto/interface.yml
RENAMED
File without changes

conf/{generated → generated-v0}/nasralla/c2f.yml
RENAMED
File without changes

conf/{generated → generated-v0}/nasralla/coarse.yml
RENAMED
File without changes

conf/{generated → generated-v0}/nasralla/interface.yml
RENAMED
File without changes
conf/generated/musica-bolero-marimba/c2f.yml
ADDED
@@ -0,0 +1,18 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/musica-bolero-marimba/c2f
+train/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
+val/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
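These generated configs are plain YAML, so they can be inspected before a job is launched. A minimal sketch (not part of the commit), assuming PyYAML is installed and the file above exists at the path listed in this commit:

```python
from pathlib import Path
import yaml

# Read the generated fine-tune config and look at the keys the training script consumes.
# "$include" is resolved when the file is loaded through argbind; here we only read the raw YAML.
conf = yaml.safe_load(Path("conf/generated/musica-bolero-marimba/c2f.yml").read_text())
print(conf["fine_tune_checkpoint"])  # ./models/spotdl/c2f.pth
print(conf["save_path"])             # ./runs/musica-bolero-marimba/c2f
```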
conf/generated/musica-bolero-marimba/coarse.yml
ADDED
@@ -0,0 +1,11 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/musica-bolero-marimba/coarse
+train/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
+val/AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
conf/generated/musica-bolero-marimba/interface.yml
ADDED
@@ -0,0 +1,8 @@
+AudioLoader.sources:
+- /media/CHONK/hugo/loras/boleros
+- /media/CHONK/hugo/loras/marimba-honduras
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
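The `Interface.*` keys mirror the keyword arguments of `vampnet.interface.Interface` (see the interface.py diff below). A hedged sketch of how such a file maps onto constructor kwargs, assuming the YAML has already been read into a dict named `conf`; this is only an illustration of the naming convention, roughly what argbind's binding does when `demo.py` is launched with `--args.load`:

```python
# Strip the "Interface." prefix to recover constructor keyword arguments.
kwargs = {
    key.split(".", 1)[1]: value
    for key, value in conf.items()
    if key.startswith("Interface.")
}
# kwargs now holds coarse_ckpt, coarse_lora_ckpt, coarse2fine_ckpt,
# coarse2fine_lora_ckpt, and codec_ckpt, ready for Interface(**kwargs).
```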
conf/generated/xeno-canto/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/c2f.pth
+save_path: ./runs/xeno-canto/c2f
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/xeno-canto-2/
+val/AudioLoader.sources: *id001
conf/generated/xeno-canto/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/xeno-canto/coarse
+train/AudioLoader.sources: &id001
+- /media/CHONK/hugo/loras/xeno-canto-2/
+val/AudioLoader.sources: *id001
conf/generated/xeno-canto/interface.yml
ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - /media/CHONK/hugo/loras/xeno-canto-2/
+Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth
+Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth
+Interface.codec_ckpt: ./models/spotdl/codec.pth
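The doubled dash under `AudioLoader.sources` is a nested list: `fine_tune.py` wraps the already-list-valued `audio_files_or_folders` in another list (see its diff below). A small illustration, not part of the commit:

```python
import yaml

# Wrapping a list in another list is what produces the "- -" entry above.
sources = ["/media/CHONK/hugo/loras/xeno-canto-2/"]
print(yaml.dump({"AudioLoader.sources": [sources]}))
# AudioLoader.sources:
# - - /media/CHONK/hugo/loras/xeno-canto-2/
```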
conf/lora/lora.yml
CHANGED
@@ -13,10 +13,10 @@ NoamScheduler.warmup: 500
 batch_size: 7
 num_workers: 7
 epoch_length: 100
-save_audio_epochs:
+save_audio_epochs: 10
 
 AdamW.lr: 0.0001
 
 # let's us organize sound classes into folders and choose from those sound classes uniformly
 AudioDataset.without_replacement: False
-max_epochs:
+max_epochs: 500
conf/vampnet.yml
CHANGED
@@ -1,5 +1,5 @@
 
-codec_ckpt:
+codec_ckpt: ./models/spotdl/codec.pth
 save_path: ckpt
 max_epochs: 1000
 epoch_length: 1000
demo.py
CHANGED
@@ -104,7 +104,11 @@ def _vamp(data, return_mask=False):
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
-
+    if data[topk] is not None:
+        top_k = data[topk] if data[topk] > 0 else None
+    else:
+        top_k = None
+
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
@@ -354,17 +358,16 @@ with gr.Blocks() as demo:
                 value=0.0
             )
 
-            vamp_button = gr.Button("vamp!!!")
 
         # mask settings
         with gr.Column():
+            vamp_button = gr.Button("vamp!!!")
             output_audio = gr.Audio(
                 label="output audio",
                 interactive=False,
                 type="filepath"
             )
 
-            use_as_input_button = gr.Button("use as input")
 
 
         # with gr.Column():
@@ -397,6 +400,7 @@
             label="vamp to download will appear here",
             interactive=False
         )
+        use_as_input_button = gr.Button("use output as input")
 
         thank_you = gr.Markdown("")
 
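The added lines in `_vamp` normalize the top-k slider value before `coarse_vamp` is called. A standalone sketch of that logic, with `raw_value` standing in for `data[topk]` (hypothetical name, not part of the commit):

```python
def resolve_top_k(raw_value):
    # None or a non-positive slider value both mean "no top-k filtering"
    if raw_value is None:
        return None
    return raw_value if raw_value > 0 else None

assert resolve_top_k(None) is None
assert resolve_top_k(0) is None
assert resolve_top_k(50) == 50
```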
scripts/exp/fine_tune.py
CHANGED
@@ -1,6 +1,7 @@
 import argbind
 from pathlib import Path
 import yaml
+from typing import List
 
 
 
@@ -10,7 +11,7 @@ import yaml
 """
 
 @argbind.bind(without_prefix=True, positional=True)
-def fine_tune(audio_file_or_folder: str, name: str):
+def fine_tune(audio_files_or_folders: List[str], name: str):
 
     conf_dir = Path("conf")
     assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
@@ -24,8 +25,8 @@ def fine_tune(audio_file_or_folder: str, name: str):
     finetune_c2f_conf = {
         "$include": ["conf/lora/lora.yml"],
         "fine_tune": True,
-        "train/AudioLoader.sources":
-        "val/AudioLoader.sources":
+        "train/AudioLoader.sources": audio_files_or_folders,
+        "val/AudioLoader.sources": audio_files_or_folders,
         "VampNet.n_codebooks": 14,
         "VampNet.n_conditioning_codebooks": 4,
         "VampNet.embedding_dim": 1280,
@@ -34,21 +35,27 @@ def fine_tune(audio_file_or_folder: str, name: str):
         "AudioDataset.duration": 3.0,
         "AudioDataset.loudness_cutoff": -40.0,
         "save_path": f"./runs/{name}/c2f",
+        "fine_tune_checkpoint": "./models/spotdl/c2f.pth"
     }
 
     finetune_coarse_conf = {
         "$include": ["conf/lora/lora.yml"],
         "fine_tune": True,
-        "train/AudioLoader.sources":
-        "val/AudioLoader.sources":
+        "train/AudioLoader.sources": audio_files_or_folders,
+        "val/AudioLoader.sources": audio_files_or_folders,
         "save_path": f"./runs/{name}/coarse",
+        "fine_tune_checkpoint": "./models/spotdl/coarse.pth"
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./
-        "Interface.
+        "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth",
+        "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
+
+        "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth",
+        "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
+
         "Interface.codec_ckpt": "./models/spotdl/codec.pth",
-        "AudioLoader.sources": [
+        "AudioLoader.sources": [audio_files_or_folders],
     }
 
     # save the confs
@@ -61,18 +68,8 @@ def fine_tune(audio_file_or_folder: str, name: str):
     with open(finetune_dir / "interface.yml", "w") as f:
         yaml.dump(interface_conf, f)
 
-    # copy the starter weights to the save paths
-    import shutil
-
-    def pmkdir(path):
-        Path(path).parent.mkdir(exist_ok=True, parents=True)
-        return path
-
-    shutil.copy("./models/spotdl/c2f.pth", pmkdir(f"./runs/{name}/c2f/starter/vampnet/weights.pth"))
-    shutil.copy("./models/spotdl/coarse.pth", pmkdir(f"./runs/{name}/coarse/starter/vampnet/weights.pth"))
-
 
-    print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml
+    print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
 
 if __name__ == "__main__":
     args = argbind.parse_args()
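A small sanity check, not part of the commit, that the checkpoint paths written into a generated config actually exist before a training job is launched; the example path is one of the configs added in this commit:

```python
from pathlib import Path
import yaml

def check_checkpoints(conf_path: str):
    # Flag any *.pth path referenced by the generated config that is missing on disk.
    conf = yaml.safe_load(Path(conf_path).read_text())
    for key, value in conf.items():
        if isinstance(value, str) and value.endswith(".pth") and not Path(value).exists():
            print(f"missing checkpoint for {key}: {value}")

check_checkpoints("conf/generated/xeno-canto/coarse.yml")
```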
scripts/exp/train.py
CHANGED
@@ -107,7 +107,11 @@ def load(
     resume: bool = False,
     tag: str = "latest",
     load_weights: bool = False,
+    fine_tune_checkpoint: Optional[str] = None,
 ):
+    codec = LAC.load(args["codec_ckpt"], map_location="cpu")
+    codec.eval()
+
     model, v_extra = None, {}
 
     if resume:
@@ -123,8 +127,12 @@
                 f"Could not find a VampNet checkpoint in {kwargs['folder']}"
             )
 
-
-
+
+    if args["fine_tune"]:
+        assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
+        model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
+
+
     model = VampNet() if model is None else model
 
     model = accel.prepare_model(model)
@@ -460,6 +468,15 @@
                 self.print(f"Best model so far")
                 tags.append("best")
 
+            if fine_tune:
+                for tag in tags:
+                    # save the lora model
+                    (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
+                    torch.save(
+                        lora.lora_state_dict(accel.unwrap(model)),
+                        f"{save_path}/{tag}/lora.pth"
+                    )
+
             for tag in tags:
                 model_extra = {
                     "optimizer.pth": optimizer.state_dict(),
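The fine-tune branch saves only the adapter weights via `lora.lora_state_dict`, so each tag directory gets a small `lora.pth` instead of a second copy of the full model. A minimal sketch of that idea, assuming the `lora` module used in train.py is the `loralib` package:

```python
import torch
import loralib as lora  # assumption: `lora` in train.py refers to loralib

def save_lora_only(model: torch.nn.Module, path: str) -> None:
    # Persist just the low-rank adapter parameters; the frozen base weights
    # stay in the original ./models/spotdl/*.pth checkpoints.
    torch.save(lora.lora_state_dict(model), path)
```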
vampnet/interface.py
CHANGED
@@ -21,12 +21,40 @@ def signal_concat(
 
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
+def _load_model(
+    ckpt: str,
+    lora_ckpt: str = None,
+    device: str = "cpu",
+    chunk_size_s: int = 10,
+):
+    # we need to set strict to False if the model has lora weights to add later
+    model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)
+
+    # load lora weights if needed
+    if lora_ckpt is not None:
+        if not Path(lora_ckpt).exists():
+            should_cont = input(
+                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
+            )
+            if should_cont != "y":
+                raise Exception("aborting")
+        else:
+            model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
+
+    model.to(device)
+    model.eval()
+    model.chunk_size_s = chunk_size_s
+    return model
+
+
 
 class Interface(torch.nn.Module):
     def __init__(
         self,
         coarse_ckpt: str = None,
+        coarse_lora_ckpt: str = None,
         coarse2fine_ckpt: str = None,
+        coarse2fine_lora_ckpt: str = None,
         codec_ckpt: str = None,
         wavebeat_ckpt: str = None,
         device: str = "cpu",
@@ -40,18 +68,21 @@ class Interface(torch.nn.Module):
         self.codec.to(device)
 
         assert coarse_ckpt is not None, "must provide a coarse checkpoint"
-        self.coarse =
-
-
-
+        self.coarse = _load_model(
+            ckpt=coarse_ckpt,
+            lora_ckpt=coarse_lora_ckpt,
+            device=device,
+            chunk_size_s=coarse_chunk_size_s,
+        )
 
+        # check if we have a coarse2fine ckpt
         if coarse2fine_ckpt is not None:
-            self.c2f =
-
+            self.c2f = _load_model(
+                ckpt=coarse2fine_ckpt,
+                lora_ckpt=coarse2fine_lora_ckpt,
+                device=device,
+                chunk_size_s=coarse2fine_chunk_size_s,
             )
-            self.c2f.to(device)
-            self.c2f.eval()
-            self.c2f.chunk_size_s = self.s2t2s(coarse2fine_chunk_size_s)
         else:
             self.c2f = None
 
@@ -64,6 +95,21 @@ class Interface(torch.nn.Module):
 
         self.device = device
 
+    def lora_load(
+        self,
+        coarse_lora_ckpt: str = None,
+        coarse2fine_lora_ckpt: str = None,
+    ):
+        if coarse_lora_ckpt is not None:
+            self.coarse.to("cpu")
+            self.coarse.load_state_dict(torch.load(coarse_lora_ckpt, map_location="cpu"))
+            self.coarse.to(self.device)
+        if coarse2fine_lora_ckpt is not None:
+            self.c2f.to("cpu")
+            self.c2f.load_state_dict(torch.load(coarse2fine_lora_ckpt, map_location="cpu"))
+            self.c2f.to(self.device)
+
+
     def s2t(self, seconds: float):
         """seconds to tokens"""
         if isinstance(seconds, np.ndarray):
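Putting the interface changes together, a hedged usage sketch (not part of the commit): the keyword names are taken from the diff above, and the checkpoint paths are the ones added under `conf/generated/` in this commit.

```python
from vampnet.interface import Interface

# Build an Interface on top of the base checkpoints plus a fine-tuned LoRA.
interface = Interface(
    coarse_ckpt="./models/spotdl/coarse.pth",
    coarse_lora_ckpt="./runs/xeno-canto/coarse/latest/lora.pth",
    coarse2fine_ckpt="./models/spotdl/c2f.pth",
    coarse2fine_lora_ckpt="./runs/xeno-canto/c2f/latest/lora.pth",
    codec_ckpt="./models/spotdl/codec.pth",
    device="cuda",
)

# Swap in a different fine-tune without rebuilding the Interface.
interface.lora_load(
    coarse_lora_ckpt="./runs/musica-bolero-marimba/coarse/latest/lora.pth",
    coarse2fine_lora_ckpt="./runs/musica-bolero-marimba/c2f/latest/lora.pth",
)
```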
|