Spaces:
Running
on
T4
Running
on
T4
Merge branch 'main' of github.com:descriptinc/lyrebird-vampnet into main
Browse files- requirements.txt +1 -1
- scripts/exp/train.py +5 -1
- setup.py +3 -1
- vampnet/modules/base.py +2 -2
requirements.txt
CHANGED
@@ -2,12 +2,12 @@ argbind>=0.3.1
|
|
2 |
pytorch-ignite
|
3 |
rich
|
4 |
audiotools @ git+https://github.com/descriptinc/[email protected]
|
|
|
5 |
tqdm
|
6 |
tensorboard
|
7 |
google-cloud-logging==2.2.0
|
8 |
pytest
|
9 |
pytest-cov
|
10 |
-
papaya_client @ git+https://github.com/descriptinc/lyrebird-papaya.git@master
|
11 |
pynvml
|
12 |
psutil
|
13 |
pandas
|
|
|
2 |
pytorch-ignite
|
3 |
rich
|
4 |
audiotools @ git+https://github.com/descriptinc/[email protected]
|
5 |
+
lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main
|
6 |
tqdm
|
7 |
tensorboard
|
8 |
google-cloud-logging==2.2.0
|
9 |
pytest
|
10 |
pytest-cov
|
|
|
11 |
pynvml
|
12 |
psutil
|
13 |
pandas
|
scripts/exp/train.py
CHANGED
@@ -59,7 +59,7 @@ IGNORE_INDEX = -100
|
|
59 |
@argbind.bind("train", "val", without_prefix=True)
|
60 |
def build_transform():
|
61 |
transform = transforms.Compose(
|
62 |
-
tfm.VolumeNorm(("uniform", -32, -
|
63 |
tfm.VolumeChange(("uniform", -6, 3)),
|
64 |
tfm.RescaleAudio(),
|
65 |
)
|
@@ -250,6 +250,7 @@ def train(
|
|
250 |
max_epochs: int = int(100e3),
|
251 |
epoch_length: int = 1000,
|
252 |
save_audio_epochs: int = 10,
|
|
|
253 |
batch_size: int = 48,
|
254 |
grad_acc_steps: int = 1,
|
255 |
val_idx: list = [0, 1, 2, 3, 4],
|
@@ -506,6 +507,9 @@ def train(
|
|
506 |
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
507 |
self.print(f"Saving to {str(Path('.').absolute())}")
|
508 |
|
|
|
|
|
|
|
509 |
if self.is_best(engine, loss_key):
|
510 |
self.print(f"Best model so far")
|
511 |
tags.append("best")
|
|
|
59 |
@argbind.bind("train", "val", without_prefix=True)
|
60 |
def build_transform():
|
61 |
transform = transforms.Compose(
|
62 |
+
tfm.VolumeNorm(("uniform", -32, -20)),
|
63 |
tfm.VolumeChange(("uniform", -6, 3)),
|
64 |
tfm.RescaleAudio(),
|
65 |
)
|
|
|
250 |
max_epochs: int = int(100e3),
|
251 |
epoch_length: int = 1000,
|
252 |
save_audio_epochs: int = 10,
|
253 |
+
save_epochs: list = [10, 50, 100, 200, 300, 400,],
|
254 |
batch_size: int = 48,
|
255 |
grad_acc_steps: int = 1,
|
256 |
val_idx: list = [0, 1, 2, 3, 4],
|
|
|
507 |
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
508 |
self.print(f"Saving to {str(Path('.').absolute())}")
|
509 |
|
510 |
+
if self.state.epoch in save_epochs:
|
511 |
+
tags.append(f"epoch={self.state.epoch}")
|
512 |
+
|
513 |
if self.is_best(engine, loss_key):
|
514 |
self.print(f"Best model so far")
|
515 |
tags.append("best")
|
setup.py
CHANGED
@@ -30,11 +30,13 @@ setup(
|
|
30 |
"argbind>=0.3.2",
|
31 |
"pytorch-ignite",
|
32 |
"rich",
|
33 |
-
"audiotools @ git+https://github.com/descriptinc/[email protected].
|
|
|
34 |
"tqdm",
|
35 |
"tensorboard",
|
36 |
"google-cloud-logging==2.2.0",
|
37 |
"torchmetrics>=0.7.3",
|
38 |
"einops",
|
|
|
39 |
],
|
40 |
)
|
|
|
30 |
"argbind>=0.3.2",
|
31 |
"pytorch-ignite",
|
32 |
"rich",
|
33 |
+
"audiotools @ git+https://github.com/descriptinc/[email protected].3",
|
34 |
+
"lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main",
|
35 |
"tqdm",
|
36 |
"tensorboard",
|
37 |
"google-cloud-logging==2.2.0",
|
38 |
"torchmetrics>=0.7.3",
|
39 |
"einops",
|
40 |
+
"flash-attn",
|
41 |
],
|
42 |
)
|
vampnet/modules/base.py
CHANGED
@@ -153,7 +153,7 @@ class VampBase(at.ml.BaseModel):
|
|
153 |
sampling_steps: int = 12,
|
154 |
start_tokens: Optional[torch.Tensor] = None,
|
155 |
mask: Optional[torch.Tensor] = None,
|
156 |
-
temperature: Union[float, Tuple[float, float]] =
|
157 |
top_k: int = None,
|
158 |
sample: str = "gumbel",
|
159 |
renoise_mode: str = "start",
|
@@ -262,7 +262,7 @@ class VampBase(at.ml.BaseModel):
|
|
262 |
sampling_steps: int = 24,
|
263 |
start_tokens: Optional[torch.Tensor] = None,
|
264 |
mask: Optional[torch.Tensor] = None,
|
265 |
-
temperature: Union[float, Tuple[float, float]] =
|
266 |
top_k: int = None,
|
267 |
sample: str = "multinomial",
|
268 |
typical_filtering=False,
|
|
|
153 |
sampling_steps: int = 12,
|
154 |
start_tokens: Optional[torch.Tensor] = None,
|
155 |
mask: Optional[torch.Tensor] = None,
|
156 |
+
temperature: Union[float, Tuple[float, float]] = 0.8,
|
157 |
top_k: int = None,
|
158 |
sample: str = "gumbel",
|
159 |
renoise_mode: str = "start",
|
|
|
262 |
sampling_steps: int = 24,
|
263 |
start_tokens: Optional[torch.Tensor] = None,
|
264 |
mask: Optional[torch.Tensor] = None,
|
265 |
+
temperature: Union[float, Tuple[float, float]] = 0.8,
|
266 |
top_k: int = None,
|
267 |
sample: str = "multinomial",
|
268 |
typical_filtering=False,
|