Spaces:
Sleeping
Sleeping
DveloperY0115
commited on
Commit
•
ddec8ca
1
Parent(s):
85a3747
Set map_location for pre-trained weights
Browse files
app.py
CHANGED
@@ -4,9 +4,6 @@ app.py
|
|
4 |
An interactive demo of text-guided shape generation.
|
5 |
"""
|
6 |
|
7 |
-
import os
|
8 |
-
os.system("pip install -e ./custom_wheels/salad-0.1-py3-none-any.whl")
|
9 |
-
|
10 |
from pathlib import Path
|
11 |
from typing import Literal
|
12 |
|
@@ -32,7 +29,10 @@ def load_model(
|
|
32 |
checkpoint_dir = Path(__file__).parent / "checkpoints"
|
33 |
c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
|
34 |
model = hydra.utils.instantiate(c)
|
35 |
-
ckpt = torch.load(
|
|
|
|
|
|
|
36 |
model.load_state_dict(ckpt)
|
37 |
model.eval()
|
38 |
for p in model.parameters(): p.requires_grad_(False)
|
|
|
4 |
An interactive demo of text-guided shape generation.
|
5 |
"""
|
6 |
|
|
|
|
|
|
|
7 |
from pathlib import Path
|
8 |
from typing import Literal
|
9 |
|
|
|
29 |
checkpoint_dir = Path(__file__).parent / "checkpoints"
|
30 |
c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
|
31 |
model = hydra.utils.instantiate(c)
|
32 |
+
ckpt = torch.load(
|
33 |
+
checkpoint_dir / f"{model_class}/state_only.ckpt",
|
34 |
+
map_location=device,
|
35 |
+
)
|
36 |
model.load_state_dict(ckpt)
|
37 |
model.eval()
|
38 |
for p in model.parameters(): p.requires_grad_(False)
|