Spaces:
Runtime error
Runtime error
Update weights to be from same repo
Browse files- cool_models.py +2 -11
cool_models.py
CHANGED
@@ -13,9 +13,6 @@ STEPS = 100
|
|
13 |
USE_DDPM = False
|
14 |
USE_DDIM = False
|
15 |
USE_CPU = False
|
16 |
-
BERT_PATH = "./weights/bert.pt"
|
17 |
-
KL_PATH = "./weights/kl-f8.pt"
|
18 |
-
INPAINT_PATH = "./weights/inpaint.pt"
|
19 |
CLIP_SEG_PATH = './weights/rd64-uni.pth'
|
20 |
CLIP_GUIDANCE = False
|
21 |
|
@@ -79,10 +76,7 @@ def make_models():
|
|
79 |
|
80 |
|
81 |
model, diffusion = create_model_and_diffusion(**model_config)
|
82 |
-
|
83 |
-
# model.from_pretrained("alvanlii/rdm_inpaint")
|
84 |
model.load_state_dict(model_state_dict, strict=False)
|
85 |
-
# model.save_pretrained("./weights/hf_inpaint")
|
86 |
|
87 |
model.requires_grad_(CLIP_GUIDANCE).eval().to(device)
|
88 |
|
@@ -97,10 +91,7 @@ def make_models():
|
|
97 |
|
98 |
|
99 |
lpips_model = lpips.LPIPS(net="vgg").to(device)
|
100 |
-
hf_kl_path = hf_hub_download("alvanlii/
|
101 |
-
|
102 |
-
# kl_model_url = hf_hub_url("alvanlii/rdm_kl", "kl-f8.pt")
|
103 |
-
# kl_cache_path = cached_download(kl_model_url, cache_dir=".")
|
104 |
|
105 |
ldm = torch.load(hf_kl_path, map_location="cpu")
|
106 |
|
@@ -111,7 +102,7 @@ def make_models():
|
|
111 |
set_requires_grad(ldm, CLIP_GUIDANCE)
|
112 |
|
113 |
bert = BERTEmbedder(1280, 32)
|
114 |
-
hf_bert_path = hf_hub_download("alvanlii/
|
115 |
# bert = BERTEmbedder.from_pretrained("alvanlii/rdm_bert")
|
116 |
sd = torch.load(hf_bert_path, map_location="cpu")
|
117 |
bert.load_state_dict(sd)
|
|
|
13 |
USE_DDPM = False
|
14 |
USE_DDIM = False
|
15 |
USE_CPU = False
|
|
|
|
|
|
|
16 |
CLIP_SEG_PATH = './weights/rd64-uni.pth'
|
17 |
CLIP_GUIDANCE = False
|
18 |
|
|
|
76 |
|
77 |
|
78 |
model, diffusion = create_model_and_diffusion(**model_config)
|
|
|
|
|
79 |
model.load_state_dict(model_state_dict, strict=False)
|
|
|
80 |
|
81 |
model.requires_grad_(CLIP_GUIDANCE).eval().to(device)
|
82 |
|
|
|
91 |
|
92 |
|
93 |
lpips_model = lpips.LPIPS(net="vgg").to(device)
|
94 |
+
hf_kl_path = hf_hub_download("alvanlii/rdm_inpaint", "kl-f8.pt")
|
|
|
|
|
|
|
95 |
|
96 |
ldm = torch.load(hf_kl_path, map_location="cpu")
|
97 |
|
|
|
102 |
set_requires_grad(ldm, CLIP_GUIDANCE)
|
103 |
|
104 |
bert = BERTEmbedder(1280, 32)
|
105 |
+
hf_bert_path = hf_hub_download("alvanlii/rdm_inpaint", 'bert.pt')
|
106 |
# bert = BERTEmbedder.from_pretrained("alvanlii/rdm_bert")
|
107 |
sd = torch.load(hf_bert_path, map_location="cpu")
|
108 |
bert.load_state_dict(sd)
|