RamAnanth1 committed
Commit • 514015e
1 Parent(s): b6b5d48
Create utils.py
utils.py ADDED
import os

import numpy as np  # used by custom_to_pil below
import torch
from PIL import Image

from lvdm.models.modules.lora import net_load_lora
from lvdm.utils.common_utils import instantiate_from_config


# ------------------------------------------------------------------------------------------
def load_model(config, ckpt_path, gpu_id=None, inject_lora=False, lora_scale=1.0, lora_path=''):
    print(f"Loading model from {ckpt_path}")

    # load checkpoint; a Lightning checkpoint carries training metadata,
    # a bare state dict does not, so fall back to -1 for the missing fields
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    global_step = pl_sd.get("global_step", -1)
    epoch = pl_sd.get("epoch", -1)

    # load the state dict into the model (a Lightning checkpoint nests it under "state_dict")
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=True)

    if inject_lora:
        net_load_lora(model, lora_path, alpha=lora_scale)

    # move to device & eval
    if gpu_id is not None:
        model.to(f"cuda:{gpu_id}")
    else:
        model.cuda()
    model.eval()

    return model, global_step, epoch
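
# A minimal usage sketch (hypothetical paths; assumes an OmegaConf-style config
# whose `model` entry matches the checkpoint, as instantiate_from_config expects):
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.load("configs/t2v.yaml")                # hypothetical config path
#   model, step, epoch = load_model(config, "models/t2v.ckpt", gpu_id=0)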


# ------------------------------------------------------------------------------------------
@torch.no_grad()
def get_conditions(prompts, model, batch_size, cond_fps=None):

    if isinstance(prompts, (str, int)):
        prompts = [prompts]
    if isinstance(prompts, list):
        if len(prompts) == 1:
            prompts = prompts * batch_size
        elif len(prompts) == batch_size:
            pass
        else:
            raise ValueError(f"invalid prompts length: {len(prompts)}")
    else:
        raise ValueError(f"invalid prompts: {prompts}")
    assert len(prompts) == batch_size

    # content condition: text / class label
    c = model.get_learned_conditioning(prompts)
    key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn'
    c = {key: [c]}

    # temporal condition: fps
    if getattr(model, 'cond_stage2_config', None) is not None:
        if model.cond_stage2_key == "temporal_context":
            assert cond_fps is not None
            batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)}
            fps_embd = model.cond_stage2_model(batch)
            c[model.cond_stage2_key] = fps_embd

    return c
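
# Sketch (hypothetical prompt; `model` as returned by load_model above): a single
# prompt is broadcast across the batch, and the fps embedding is attached only for
# models configured with a second, temporal conditioning stage.
#
#   cond = get_conditions("a corgi running on the beach", model, batch_size=2, cond_fps=8)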


# ------------------------------------------------------------------------------------------
def make_model_input_shape(model, batch_size, T=None):
    # [B, C, T, H, W]: channels and temporal length come from the diffusion UNet config
    image_size = [model.image_size, model.image_size] if isinstance(model.image_size, int) else model.image_size
    C = model.model.diffusion_model.in_channels
    if T is None:
        T = model.model.diffusion_model.temporal_length
    shape = [batch_size, C, T, *image_size]
    return shape
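
# Sketch: the shape typically seeds the sampler's initial noise, e.g.
#
#   shape = make_model_input_shape(model, batch_size=2)   # e.g. [2, C, T, H, W]
#   noise = torch.randn(shape, device=model.device)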


# ------------------------------------------------------------------------------------------
def custom_to_pil(x):
    # map a single [C, H, W] tensor in [-1, 1] to an RGB PIL image
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if x.mode != "RGB":
        x = x.convert("RGB")
    return x


def torch_to_np(x):
    # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
    sample = x.detach().cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # video batches are [B, C, T, H, W]; image batches are [B, C, H, W];
    # either way, move channels last
    if sample.dim() == 5:
        sample = sample.permute(0, 2, 3, 4, 1)
    else:
        sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    return sample
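
# Sketch: converting sampler output back to images, assuming `samples` is a
# [B, C, T, H, W] tensor in [-1, 1] as produced by the model's decoder:
#
#   frames = torch_to_np(samples)                    # [B, T, H, W, 3] uint8 tensor
#   first_frame = custom_to_pil(samples[0, :, 0])    # one frame as a PIL image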
def make_sample_dir(opt, global_step=None, epoch=None):
    if not getattr(opt, 'not_automatic_logdir', False):
        gs_str = f"globalstep{global_step:09}" if global_step is not None else "None"
        e_str = f"epoch{epoch:06}" if epoch is not None else "None"
        ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}")

        # subdir name
        if opt.prompt_file is not None:
            subdir = f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}"
        else:
            subdir = f"prompt_{opt.prompt[:10]}"
        subdir += "_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps"
        subdir += f"_CfgScale{opt.scale}"
        if opt.cond_fps is not None:
            subdir += f"_fps{opt.cond_fps}"
        if opt.seed is not None:
            subdir += f"_seed{opt.seed}"

        return os.path.join(ckpt_dir, subdir)
    else:
        return opt.logdir
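
# Sketch (hypothetical `opt`, mirroring the attributes this helper reads; in the
# real pipeline it would be the sampling script's argparse namespace):
#
#   from argparse import Namespace
#   opt = Namespace(not_automatic_logdir=False, logdir="results", prompt_file=None,
#                   prompt="a corgi running", vanilla_sample=False, custom_steps=50,
#                   scale=15.0, cond_fps=8, seed=42)
#   save_dir = make_sample_dir(opt, global_step=10000, epoch=5)
#   # -> results/globalstep000010000_epoch000005/prompt_a corgi ru_DDIM50steps_CfgScale15.0_fps8_seed42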