Hugo Flores Garcia committed
Commit 31b771c · 1 parent: a66dc9c

dropping torch.compile for now

Files changed:
- scripts/exp/train.py  +10 -4
- scripts/utils/split_long_audio_file.py  +34 -0
scripts/exp/train.py CHANGED

@@ -29,6 +29,9 @@ from audiotools.ml.decorators import (
 
 import loralib as lora
 
+import torch._dynamo
+torch._dynamo.config.verbose=True
+
 
 # Enable cudnn autotuner to speed up training
 # (can be altered by the funcs.seed function)
@@ -510,14 +513,14 @@ def load(
 
     if args["fine_tune"]:
         assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
-        model = torch.compile(
+        model = (
             VampNet.load(location=Path(fine_tune_checkpoint),
             map_location="cpu",
             )
         )
 
 
-    model = torch.compile(VampNet() if model is None else model)
+    model = VampNet() if model is None else model
     model = accel.prepare_model(model)
 
     # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
@@ -601,7 +604,7 @@ def train(
         accel=accel,
         tracker=tracker,
         save_path=save_path)
-
+    print("initialized state.")
 
     train_dataloader = accel.prepare_dataloader(
         state.train_data,
@@ -616,13 +619,15 @@ def train(
         num_workers=num_workers,
         batch_size=batch_size,
         collate_fn=state.val_data.collate,
-        persistent_workers=
+        persistent_workers=num_workers > 0,
     )
+    print("initialized dataloader.")
 
 
 
     if fine_tune:
         lora.mark_only_lora_as_trainable(state.model)
+        print("marked only lora as trainable.")
 
     # Wrap the functions so that they neatly track in TensorBoard + progress bars
     # and only run when specific conditions are met.
@@ -637,6 +642,7 @@ def train(
     save_samples = when(lambda: accel.local_rank == 0)(save_samples)
     checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
 
+    print("starting training loop.")
     with tracker.live:
         for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
             train_loop(state, batch, accel)
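One dataloader change rides along with the revert: persistent_workers is now tied to num_workers, which matters because torch.utils.data.DataLoader rejects persistent_workers=True whenever num_workers is 0. A minimal sketch of the constraint, assuming only stock PyTorch (the toy dataset is illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))

try:
    # persistent workers require a multi-process loader
    DataLoader(dataset, num_workers=0, persistent_workers=True)
except ValueError as err:
    print(err)  # "persistent_workers option needs num_workers > 0"

# Deriving the flag from the worker count, as the diff does, stays valid
# for both single-process (num_workers=0) and multi-process loaders.
num_workers = 0
loader = DataLoader(dataset, num_workers=num_workers,
                    persistent_workers=num_workers > 0)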
scripts/utils/split_long_audio_file.py ADDED

@@ -0,0 +1,34 @@
+from pathlib import Path
+import argbind
+
+import audiotools as at
+import tqdm
+
+
+@argbind.bind(without_prefix=True)
+def split_long_audio_file(
+    file: str = None,
+    max_chunk_size_s: int = 60*10
+):
+    file = Path(file)
+    output_dir = file.parent / file.stem
+    output_dir.mkdir()
+
+    sig = at.AudioSignal(file)
+
+    # split into chunks
+    for i, sig in tqdm.tqdm(enumerate(sig.windows(
+        window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
+        preprocess=True))
+    ):
+        sig.write(output_dir / f"{i}.wav")
+
+    print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
+
+    return output_dir
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        split_long_audio_file()
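A usage sketch for the new script. With @argbind.bind(without_prefix=True), argbind should expose the function's keyword arguments as top-level CLI flags, so an invocation would look roughly like `python scripts/utils/split_long_audio_file.py --file path/to/long.wav --max_chunk_size_s 600` (the path is a placeholder). Two things worth noting from the code itself: hop_duration=max_chunk_size_s/2 gives 50% overlap between consecutive chunks, so this is a sliding-window split rather than a disjoint one, and output_dir.mkdir() is called without exist_ok=True, so rerunning on the same file raises FileExistsError. The function can also be called directly; a sketch assuming the module is importable from the repository root, with a placeholder input path:

from scripts.utils.split_long_audio_file import split_long_audio_file

# writes data/long_recording/0.wav, 1.wav, ... and returns the directory
out_dir = split_long_audio_file(
    file="data/long_recording.wav",   # placeholder input
    max_chunk_size_s=600,             # 10-minute windows, 5-minute hop
)
print(out_dir)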