Commit 405226b
Hugo Flores Garcia committed
Parent(s): 88c78e1

use torch.compile for training

scripts/exp/train.py  CHANGED  (+7 -5)
@@ -485,7 +485,6 @@ def load(
     save_path: str,
     resume: bool = False,
     tag: str = "latest",
-    load_weights: bool = False,
     fine_tune_checkpoint: Optional[str] = None,
     grad_clip_val: float = 5.0,
 ) -> State:
@@ -498,7 +497,7 @@ def load(
     kwargs = {
         "folder": f"{save_path}/{tag}",
         "map_location": "cpu",
-        "package": not load_weights,
+        "package": False,
     }
     tracker.print(f"Loading checkpoint from {kwargs['folder']}")
     if (Path(kwargs["folder"]) / "vampnet").exists():
@@ -511,11 +510,14 @@ def load(
 
     if args["fine_tune"]:
         assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
-        model = VampNet.load(location=Path(fine_tune_checkpoint),
-                             map_location="cpu")
+        model = torch.compile(
+            VampNet.load(location=Path(fine_tune_checkpoint),
+                         map_location="cpu",
+                         )
+        )
 
-    model = VampNet() if model is None else model
+    model = torch.compile(VampNet()) if model is None else model
 
     model = accel.prepare_model(model)
 
     # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
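
For context, a minimal, self-contained sketch of the pattern this commit adopts: wrap the module in torch.compile() once at construction/load time, then train it as usual. TinyNet, the optimizer, the loss, and the data below are illustrative stand-ins, not VampNet or this repo's actual training setup.

# Minimal sketch of the torch.compile training pattern (PyTorch 2.x).
# TinyNet stands in for VampNet; optimizer/loss/data are illustrative only.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = torch.compile(TinyNet())  # returns an OptimizedModule wrapping TinyNet
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):
    x = torch.randn(8, 16)
    loss = (model(x) - x).pow(2).mean()  # first call triggers compilation
    opt.zero_grad()
    loss.backward()  # backward runs through the compiled graph
    opt.step()

One caveat when mixing torch.compile with checkpointing (presumably related to hardcoding "package": False above): the compiled wrapper's state_dict keys carry a "_orig_mod." prefix, and the uncompiled module remains reachable via model._orig_mod, so any code that saves or loads checkpoints around a compiled model has to account for the extra wrapper.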