alvanlii commited on
Commit
c85686d
1 Parent(s): b2d15ad

Update weights to be from same repo

Browse files
Files changed (1) hide show
  1. cool_models.py +2 -11
cool_models.py CHANGED
@@ -13,9 +13,6 @@ STEPS = 100
13
  USE_DDPM = False
14
  USE_DDIM = False
15
  USE_CPU = False
16
- BERT_PATH = "./weights/bert.pt"
17
- KL_PATH = "./weights/kl-f8.pt"
18
- INPAINT_PATH = "./weights/inpaint.pt"
19
  CLIP_SEG_PATH = './weights/rd64-uni.pth'
20
  CLIP_GUIDANCE = False
21
 
@@ -79,10 +76,7 @@ def make_models():
79
 
80
 
81
  model, diffusion = create_model_and_diffusion(**model_config)
82
-
83
- # model.from_pretrained("alvanlii/rdm_inpaint")
84
  model.load_state_dict(model_state_dict, strict=False)
85
- # model.save_pretrained("./weights/hf_inpaint")
86
 
87
  model.requires_grad_(CLIP_GUIDANCE).eval().to(device)
88
 
@@ -97,10 +91,7 @@ def make_models():
97
 
98
 
99
  lpips_model = lpips.LPIPS(net="vgg").to(device)
100
- hf_kl_path = hf_hub_download("alvanlii/rdm_kl", "kl-f8.pt")
101
-
102
- # kl_model_url = hf_hub_url("alvanlii/rdm_kl", "kl-f8.pt")
103
- # kl_cache_path = cached_download(kl_model_url, cache_dir=".")
104
 
105
  ldm = torch.load(hf_kl_path, map_location="cpu")
106
 
@@ -111,7 +102,7 @@ def make_models():
111
  set_requires_grad(ldm, CLIP_GUIDANCE)
112
 
113
  bert = BERTEmbedder(1280, 32)
114
- hf_bert_path = hf_hub_download("alvanlii/rdm_bert", 'bert.pt')
115
  # bert = BERTEmbedder.from_pretrained("alvanlii/rdm_bert")
116
  sd = torch.load(hf_bert_path, map_location="cpu")
117
  bert.load_state_dict(sd)
 
13
  USE_DDPM = False
14
  USE_DDIM = False
15
  USE_CPU = False
 
 
 
16
  CLIP_SEG_PATH = './weights/rd64-uni.pth'
17
  CLIP_GUIDANCE = False
18
 
 
76
 
77
 
78
  model, diffusion = create_model_and_diffusion(**model_config)
 
 
79
  model.load_state_dict(model_state_dict, strict=False)
 
80
 
81
  model.requires_grad_(CLIP_GUIDANCE).eval().to(device)
82
 
 
91
 
92
 
93
  lpips_model = lpips.LPIPS(net="vgg").to(device)
94
+ hf_kl_path = hf_hub_download("alvanlii/rdm_inpaint", "kl-f8.pt")
 
 
 
95
 
96
  ldm = torch.load(hf_kl_path, map_location="cpu")
97
 
 
102
  set_requires_grad(ldm, CLIP_GUIDANCE)
103
 
104
  bert = BERTEmbedder(1280, 32)
105
+ hf_bert_path = hf_hub_download("alvanlii/rdm_inpaint", 'bert.pt')
106
  # bert = BERTEmbedder.from_pretrained("alvanlii/rdm_bert")
107
  sd = torch.load(hf_bert_path, map_location="cpu")
108
  bert.load_state_dict(sd)