Commit cd3b424
kyleleey committed
1 Parent(s): 2b1cca8

remove unused pkgs
Files changed:
- requirements.txt +0 -2
- video3d/diffusion/sd.py +0 -252
- video3d/diffusion/sd_utils.py +0 -123
- video3d/diffusion/vsd.py +0 -323
- video3d/model_ddp.py +2 -199
requirements.txt CHANGED
@@ -1,6 +1,4 @@
 ConfigArgParse==1.5.3
-core==1.0.1
-diffusers==0.20.0
 einops==0.4.1
 faiss==1.7.3
 fire==0.5.0
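
Note: since this commit drops core and diffusers from requirements.txt together with the modules that imported them, a quick repository scan is one way to confirm nothing else still depends on those packages. The snippet below is an illustrative check only, not part of the commit:

# Illustrative sanity check (not part of this commit): scan the repo for imports
# of the packages being dropped from requirements.txt.
import pathlib
import re

removed = ["core", "diffusers"]
pattern = re.compile(rf"^\s*(?:import|from)\s+({'|'.join(removed)})\b", re.MULTILINE)

for path in pathlib.Path(".").rglob("*.py"):
    text = path.read_text(errors="ignore")
    for match in pattern.finditer(text):
        print(f"{path}: still imports {match.group(1)}")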
video3d/diffusion/sd.py DELETED
@@ -1,252 +0,0 @@
-import os
-# os.environ['HUGGINGFACE_HUB_CACHE'] = '/work/tomj/cache/huggingface_hub'
-# os.environ['HF_HOME'] = '/work/tomj/cache/huggingface_hub'
-os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
-os.environ['HF_HOME'] = '/viscam/u/zzli'
-
-from transformers import CLIPTextModel, CLIPTokenizer, logging
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
-
-# Suppress partial model loading warning
-logging.set_verbosity_error()
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-class SpecifyGradient(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, input_tensor, gt_grad):
-        ctx.save_for_backward(gt_grad)
-        return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype)  # Dummy loss value
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad):
-        gt_grad, = ctx.saved_tensors
-        batch_size = len(gt_grad)
-        return gt_grad / batch_size, None
-
-def seed_everything(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-
-class StableDiffusion(nn.Module):
-    def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32):
-        super().__init__()
-
-        self.device = device
-        self.sd_version = sd_version
-        self.torch_dtype = torch_dtype
-
-        print(f'[INFO] loading stable diffusion...')
-
-        if hf_key is not None:
-            print(f'[INFO] using hugging face custom model key: {hf_key}')
-            model_key = hf_key
-        elif self.sd_version == '2.1':
-            model_key = "stabilityai/stable-diffusion-2-1-base"
-        elif self.sd_version == '2.0':
-            model_key = "stabilityai/stable-diffusion-2-base"
-        elif self.sd_version == '1.5':
-            model_key = "runwayml/stable-diffusion-v1-5"
-        else:
-            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
-
-        # Create model
-        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
-        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
-
-        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
-        # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
-
-        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
-        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience
-
-        print(f'[INFO] loaded stable diffusion!')
-
-    def get_text_embeds(self, prompt, negative_prompt):
-        # prompt, negative_prompt: [str]
-
-        # Tokenize text and get embeddings
-        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
-
-        with torch.no_grad():
-            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
-        # Do the same for unconditional embeddings
-        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
-
-        with torch.no_grad():
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
-        # Cat for final embeddings
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-        return text_embeddings
-
-    def train_step(self, text_embeddings, pred_rgb,
-                   guidance_scale=100, loss_weight=1.0, min_step_pct=0.02, max_step_pct=0.98, return_aux=False):
-        pred_rgb = pred_rgb.to(self.torch_dtype)
-        text_embeddings = text_embeddings.to(self.torch_dtype)
-        b = pred_rgb.shape[0]
-
-        # interp to 512x512 to be fed into vae.
-
-        # _t = time.time()
-        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
-        # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
-
-        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
-        min_step = int(self.num_train_timesteps * min_step_pct)
-        max_step = int(self.num_train_timesteps * max_step_pct)
-        t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
-
-        # encode image into latents with vae, requires grad!
-        # _t = time.time()
-        latents = self.encode_imgs(pred_rgb_512)
-        # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
-
-        # predict the noise residual with unet, NO grad!
-        # _t = time.time()
-        with torch.no_grad():
-            # add noise
-            noise = torch.randn_like(latents)
-            latents_noisy = self.scheduler.add_noise(latents, noise, t)
-            # pred noise
-            latent_model_input = torch.cat([latents_noisy] * 2)
-            t_input = torch.cat([t, t])
-            noise_pred = self.unet(latent_model_input, t_input, encoder_hidden_states=text_embeddings).sample
-        # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
-
-        # perform guidance (high scale from paper!)
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        # noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-        # w(t), sigma_t^2
-        w = (1 - self.alphas[t])
-        # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
-        grad = loss_weight * w[:, None, None, None] * (noise_pred - noise)
-
-        # clip grad for stable training?
-        # grad = grad.clamp(-10, 10)
-        grad = torch.nan_to_num(grad)
-
-        # since we omitted an item in grad, we need to use the custom function to specify the gradient
-        # _t = time.time()
-        # loss = SpecifyGradient.apply(latents, grad)
-        # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
-
-        targets = (latents - grad).detach()
-        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
-
-        if return_aux:
-            aux = {'grad': grad, 't': t, 'w': w}
-            return loss, aux
-        else:
-            return loss
-
-
-    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
-
-        if latents is None:
-            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
-
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        with torch.autocast('cuda'):
-            for i, t in enumerate(self.scheduler.timesteps):
-                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-                latent_model_input = torch.cat([latents] * 2)
-
-                # predict the noise residual
-                with torch.no_grad():
-                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
-
-                # perform guidance
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
-
-        return latents
-
-    def decode_latents(self, latents):
-
-        latents = 1 / self.vae.config.scaling_factor * latents
-
-        with torch.no_grad():
-            imgs = self.vae.decode(latents).sample
-
-        imgs = (imgs / 2 + 0.5).clamp(0, 1)
-
-        return imgs
-
-    def encode_imgs(self, imgs):
-        # imgs: [B, 3, H, W]
-
-        imgs = 2 * imgs - 1
-
-        posterior = self.vae.encode(imgs).latent_dist
-        latents = posterior.sample() * self.vae.config.scaling_factor
-
-        return latents
-
-    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
-
-        if isinstance(prompts, str):
-            prompts = [prompts]
-
-        if isinstance(negative_prompts, str):
-            negative_prompts = [negative_prompts]
-
-        # Prompts -> text embeds
-        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]
-
-        # Text embeds -> img latents
-        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)  # [1, 4, 64, 64]
-
-        # Img latents -> imgs
-        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]
-
-        # Img to Numpy
-        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
-        imgs = (imgs * 255).round().astype('uint8')
-
-        return imgs
-
-
-if __name__ == '__main__':
-    import argparse
-    import matplotlib.pyplot as plt
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('prompt', type=str)
-    parser.add_argument('--negative', default='', type=str)
-    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
-    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
-    parser.add_argument('-H', type=int, default=512)
-    parser.add_argument('-W', type=int, default=512)
-    parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--steps', type=int, default=50)
-    opt = parser.parse_args()
-
-    seed_everything(opt.seed)
-
-    device = torch.device('cuda')
-
-    sd = StableDiffusion(device, opt.sd_version, opt.hf_key)
-
-    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
-
-    # visualize image
-    plt.imshow(imgs[0])
-    plt.show()
-    plt.savefig(f'{opt.prompt}.png')
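
Note: the deleted train_step injects the score-distillation gradient without backpropagating through the UNet. It builds a detached target, latents - grad, and minimizes a scaled MSE, which is equivalent to the older SpecifyGradient trick left in the comments. A minimal standalone check of that equivalence (illustrative only, not code from the repo):

# Illustrative check: with a detached target, d(loss)/d(latents) equals the injected gradient.
import torch
import torch.nn.functional as F

latents = torch.randn(2, 4, 64, 64, requires_grad=True)
grad = torch.randn_like(latents)          # stands in for w(t) * (noise_pred - noise)

targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents, targets, reduction='sum') / latents.shape[0]
loss.backward()

# latents.grad equals grad / batch_size, matching what SpecifyGradient.backward returned.
assert torch.allclose(latents.grad, grad / latents.shape[0])
print("SDS surrogate loss reproduces the injected gradient.")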
video3d/diffusion/sd_utils.py DELETED
@@ -1,123 +0,0 @@
-import torch
-import numpy as np
-import random
-import torch.nn.functional as F
-
-from ..render.light import DirectionalLight
-
-def safe_normalize(x, eps=1e-20):
-    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
-
-def get_view_direction(thetas, phis, overhead, front, phi_offset=0):
-    # phis [B,]; thetas: [B,]
-    # front = 0          [360 - front / 2, front / 2)
-    # side (left) = 1    [front / 2, 180 - front / 2)
-    # back = 2           [180 - front / 2, 180 + front / 2)
-    # side (right) = 3   [180 + front / 2, 360 - front / 2)
-    # top = 4            [0, overhead]
-    # bottom = 5         [180-overhead, 180]
-    res = torch.zeros(thetas.shape[0], dtype=torch.long)
-
-    # first determine by phis
-    phi_offset = np.deg2rad(phi_offset)
-    phis = phis + phi_offset
-    phis = phis % (2 * np.pi)
-    half_front = front / 2
-
-    res[(phis >= (2*np.pi - half_front)) | (phis < half_front)] = 0
-    res[(phis >= half_front) & (phis < (np.pi - half_front))] = 1
-    res[(phis >= (np.pi - half_front)) & (phis < (np.pi + half_front))] = 2
-    res[(phis >= (np.pi + half_front)) & (phis < (2*np.pi - half_front))] = 3
-
-    # override by thetas
-    res[thetas <= overhead] = 4
-    res[thetas >= (np.pi - overhead)] = 5
-    return res
-
-
-def view_direction_id_to_text(view_direction_id):
-    dir_texts = ['front', 'side', 'back', 'side', 'overhead', 'bottom']
-    return [dir_texts[i] for i in view_direction_id]
-
-
-def append_text_direction(prompts, dir_texts):
-    return [f'{prompt}, {dir_text} view' for prompt, dir_text in zip(prompts, dir_texts)]
-
-
-def rand_lights(camera_dir, fixed_ambient, fixed_diffuse):
-    size = camera_dir.shape[0]
-    device = camera_dir.device
-    random_fixed_dir = F.normalize(torch.randn_like(camera_dir) + camera_dir, dim=-1)  # Centered around camera_dir
-    random_fixed_intensity = torch.tensor([fixed_ambient, fixed_diffuse], device=device)[None, :].repeat(size, 1)  # ambient, diffuse
-    return DirectionalLight(mlp_in=1, mlp_layers=1, mlp_hidden_size=1,  # Dummy values
-                            intensity_min_max=[0.5, 1], fixed_dir=random_fixed_dir, fixed_intensity=random_fixed_intensity).to(device)
-
-def rand_poses(size, device, radius_range=[1, 1], theta_range=[0, 120], phi_range=[0, 360], cam_z_offset=10, return_dirs=False, angle_overhead=30, angle_front=60, phi_offset=0, jitter=False, uniform_sphere_rate=0.5):
-    ''' generate random poses from an orbit camera
-    Args:
-        size: batch size of generated poses.
-        device: where to allocate the output.
-        radius_range: [min, max]
-        theta_range: [min, max], should be in [0, pi]
-        phi_range: [min, max], should be in [0, 2 * pi]
-    Return:
-        poses: [size, 4, 4]
-    '''
-
-    theta_range = np.deg2rad(theta_range)
-    phi_range = np.deg2rad(phi_range)
-    angle_overhead = np.deg2rad(angle_overhead)
-    angle_front = np.deg2rad(angle_front)
-
-    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
-
-    phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
-    if random.random() < uniform_sphere_rate:
-        # based on http://corysimon.github.io/articles/uniformdistn-on-sphere/
-        # acos takes in [-1, 1], first convert theta range to fit in [-1, 1]
-        theta_range = torch.from_numpy(np.array(theta_range)).to(device)
-        theta_amplitude_range = torch.cos(theta_range)
-        # sample uniformly in amplitude space range
-        thetas_amplitude = torch.rand(size, device=device) * (theta_amplitude_range[1] - theta_amplitude_range[0]) + theta_amplitude_range[0]
-        # convert back
-        thetas = torch.acos(thetas_amplitude)
-    else:
-        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
-
-    centers = -torch.stack([
-        radius * torch.sin(thetas) * torch.sin(phis),
-        radius * torch.cos(thetas),
-        radius * torch.sin(thetas) * torch.cos(phis),
-    ], dim=-1)  # [B, 3]
-
-    targets = 0
-
-    # jitters
-    if jitter:
-        centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
-        targets = targets + torch.randn_like(centers) * 0.2
-
-    # lookat
-    forward_vector = safe_normalize(targets - centers)
-    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
-    right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1))
-
-    if jitter:
-        up_noise = torch.randn_like(up_vector) * 0.02
-    else:
-        up_noise = 0
-
-    up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1) + up_noise)
-
-    poses = torch.stack([right_vector, up_vector, forward_vector], dim=-1)
-    radius = radius[..., None] - cam_z_offset
-    translations = torch.cat([torch.zeros_like(radius), torch.zeros_like(radius), radius], dim=-1)
-    poses = torch.cat([poses.view(-1, 9), translations], dim=-1)
-
-    if return_dirs:
-        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_offset=phi_offset)
-        dirs = view_direction_id_to_text(dirs)
-    else:
-        dirs = None
-
-    return poses, dirs
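
Note: the deleted rand_poses mixes two theta-sampling strategies, plain uniform in angle and uniform over the sphere's surface area via inverse-CDF sampling of cos(theta) (the linked corysimon note). A compact sketch of the latter, written here only to illustrate the idea; the function name and range values are not from the repo:

# Illustrative sketch: sample theta uniformly over the sphere cap by drawing
# cos(theta) uniformly and taking acos (area element ~ sin(theta) d(theta)).
import torch

def sample_thetas_uniform_on_sphere(size, theta_min, theta_max):
    cos_range = torch.cos(torch.tensor([theta_min, theta_max]))
    u = torch.rand(size) * (cos_range[1] - cos_range[0]) + cos_range[0]
    return torch.acos(u)

thetas = sample_thetas_uniform_on_sphere(10000, 0.0, 2.0943951)  # [0, 120 deg] in radians
print(thetas.min().item(), thetas.max().item())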
video3d/diffusion/vsd.py DELETED
@@ -1,323 +0,0 @@
-import os
-os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
-os.environ['HF_HOME'] = '/viscam/u/zzli'
-
-from transformers import CLIPTextModel, CLIPTokenizer, logging
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
-
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
-from diffusers.models.embeddings import TimestepEmbedding
-from diffusers.utils.import_utils import is_xformers_available
-
-# Suppress partial model loading warning
-logging.set_verbosity_error()
-
-import gc
-import random
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import tinycudann as tcnn
-from video3d.diffusion.sd import StableDiffusion
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-
-def seed_everything(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-def cleanup():
-    gc.collect()
-    torch.cuda.empty_cache()
-    tcnn.free_temporary_memory()
-
-class StableDiffusion_VSD(StableDiffusion):
-    def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32, lora_n_timestamp_samples=1):
-        super().__init__(device, sd_version=sd_version, hf_key=hf_key, torch_dtype=torch_dtype)
-
-        # self.device = device
-        # self.sd_version = sd_version
-        # self.torch_dtype = torch_dtype
-
-        if hf_key is not None:
-            print(f'[INFO] using hugging face custom model key: {hf_key}')
-            model_key = hf_key
-        elif self.sd_version == '2.1':
-            model_key = "stabilityai/stable-diffusion-2-1-base"
-        elif self.sd_version == '2.0':
-            model_key = "stabilityai/stable-diffusion-2-base"
-        elif self.sd_version == '1.5':
-            model_key = "runwayml/stable-diffusion-v1-5"
-        else:
-            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
-
-        # # Create model
-        # self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
-        # self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
-        # self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
-        # self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
-
-        # self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
-        # # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
-
-        # self.num_train_timesteps = self.scheduler.config.num_train_timesteps
-        # self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience
-
-        print(f'[INFO] loading stable diffusion VSD modules...')
-
-        self.unet_lora = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
-        cleanup()
-
-        for p in self.vae.parameters():
-            p.requires_grad_(False)
-        for p in self.text_encoder.parameters():
-            p.requires_grad_(False)
-        for p in self.unet.parameters():
-            p.requires_grad_(False)
-        for p in self.unet_lora.parameters():
-            p.requires_grad_(False)
-
-        # set up LoRA layers
-        lora_attn_procs = {}
-        for name in self.unet_lora.attn_processors.keys():
-            cross_attention_dim = (
-                None
-                if name.endswith("attn1.processor")
-                else self.unet_lora.config.cross_attention_dim
-            )
-            if name.startswith("mid_block"):
-                hidden_size = self.unet_lora.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[
-                    block_id
-                ]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = self.unet_lora.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-            )
-
-        self.unet_lora.set_attn_processor(lora_attn_procs)
-
-        self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
-            self.device
-        )
-        self.lora_layers._load_state_dict_pre_hooks.clear()
-        self.lora_layers._state_dict_hooks.clear()
-        self.lora_n_timestamp_samples = lora_n_timestamp_samples
-        self.scheduler_lora = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
-
-        print(f'[INFO] loaded stable diffusion VSD modules!')
-
-    def train_lora(
-        self,
-        latents,
-        text_embeddings,
-        camera_condition
-    ):
-        B = latents.shape[0]
-        lora_n_timestamp_samples = self.lora_n_timestamp_samples
-        latents = latents.detach().repeat(lora_n_timestamp_samples, 1, 1, 1)
-
-        t = torch.randint(
-            int(self.num_train_timesteps * 0.0),
-            int(self.num_train_timesteps * 1.0),
-            [B * lora_n_timestamp_samples],
-            dtype=torch.long,
-            device=self.device,
-        )
-
-        noise = torch.randn_like(latents)
-        noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
-        if self.scheduler_lora.config.prediction_type == "epsilon":
-            target = noise
-        elif self.scheduler_lora.config.prediction_type == "v_prediction":
-            target = self.scheduler_lora.get_velocity(latents, noise, t)
-        else:
-            raise ValueError(
-                f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
-            )
-
-        # use view-independent text embeddings in LoRA
-        _, text_embeddings_cond = text_embeddings.chunk(2)
-
-        if random.random() < 0.1:
-            camera_condition = torch.zeros_like(camera_condition)
-
-        noise_pred = self.unet_lora(
-            noisy_latents,
-            t,
-            encoder_hidden_states=text_embeddings_cond.repeat(
-                lora_n_timestamp_samples, 1, 1
-            ),
-            class_labels=camera_condition.reshape(B, -1).repeat(
-                lora_n_timestamp_samples, 1
-            ),
-            cross_attention_kwargs={"scale": 1.0}
-        ).sample
-
-        loss_lora = 0.5 * F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-        return loss_lora
-
-
-    def train_step(
-        self,
-        text_embeddings,
-        text_embeddings_vd,
-        pred_rgb,
-        camera_condition,
-        im_features,
-        guidance_scale=7.5,
-        guidance_scale_lora=7.5,
-        loss_weight=1.0,
-        min_step_pct=0.02,
-        max_step_pct=0.98,
-        return_aux=False
-    ):
-        pred_rgb = pred_rgb.to(self.torch_dtype)
-        text_embeddings = text_embeddings.to(self.torch_dtype)
-        text_embeddings_vd = text_embeddings_vd.to(self.torch_dtype)
-        camera_condition = camera_condition.to(self.torch_dtype)
-        im_features = im_features.to(self.torch_dtype)
-
-        # condition_label = camera_condition
-        condition_label = im_features
-
-        b = pred_rgb.shape[0]
-
-        # interp to 512x512 to be fed into vae.
-        # _t = time.time()
-        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
-        # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
-
-        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
-        min_step = int(self.num_train_timesteps * min_step_pct)
-        max_step = int(self.num_train_timesteps * max_step_pct)
-        t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
-
-        # encode image into latents with vae, requires grad!
-        # _t = time.time()
-        latents = self.encode_imgs(pred_rgb_512)
-        # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
-
-        # predict the noise residual with unet, NO grad!
-        # _t = time.time()
-        with torch.no_grad():
-            # add noise
-            noise = torch.randn_like(latents)
-            latents_noisy = self.scheduler.add_noise(latents, noise, t)
-            # pred noise
-            latent_model_input = torch.cat([latents_noisy] * 2)
-
-            # disable unet class embedding here
-            cls_embedding = self.unet.class_embedding
-            self.unet.class_embedding = None
-
-            cross_attention_kwargs = None
-            noise_pred_pretrain = self.unet(
-                latent_model_input,
-                torch.cat([t, t]),
-                encoder_hidden_states=text_embeddings_vd,
-                class_labels=None,
-                cross_attention_kwargs=cross_attention_kwargs
-            ).sample
-
-            self.unet.class_embedding = cls_embedding
-
-            # use view-independent text embeddings in LoRA
-            _, text_embeddings_cond = text_embeddings.chunk(2)
-
-            noise_pred_est = self.unet_lora(
-                latent_model_input,
-                torch.cat([t, t]),
-                encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
-                class_labels=torch.cat(
-                    [
-                        condition_label.reshape(b, -1),
-                        torch.zeros_like(condition_label.reshape(b, -1)),
-                    ],
-                    dim=0,
-                ),
-                cross_attention_kwargs={"scale": 1.0},
-            ).sample
-
-        noise_pred_pretrain_uncond, noise_pred_pretrain_text = noise_pred_pretrain.chunk(2)
-
-        noise_pred_pretrain = noise_pred_pretrain_uncond + guidance_scale * (
-            noise_pred_pretrain_text - noise_pred_pretrain_uncond
-        )
-
-        assert self.scheduler.config.prediction_type == "epsilon"
-        if self.scheduler_lora.config.prediction_type == "v_prediction":
-            alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
-                device=latents_noisy.device, dtype=latents_noisy.dtype
-            )
-            alpha_t = alphas_cumprod[t] ** 0.5
-            sigma_t = (1 - alphas_cumprod[t]) ** 0.5
-
-            noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).reshape(
-                -1, 1, 1, 1
-            ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).reshape(-1, 1, 1, 1)
-
-        noise_pred_est_uncond, noise_pred_est_camera = noise_pred_est.chunk(2)
-
-        noise_pred_est = noise_pred_est_uncond + guidance_scale_lora * (
-            noise_pred_est_camera - noise_pred_est_uncond
-        )
-
-        # w(t), sigma_t^2
-        w = (1 - self.alphas[t])
-        # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
-        grad = loss_weight * w[:, None, None, None] * (noise_pred_pretrain - noise_pred_est)
-
-        grad = torch.nan_to_num(grad)
-
-        targets = (latents - grad).detach()
-        loss_vsd = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
-
-        loss_lora = self.train_lora(latents, text_embeddings, condition_label)
-
-        loss = {
-            'loss_vsd': loss_vsd,
-            'loss_lora': loss_lora
-        }
-
-        if return_aux:
-            aux = {'grad': grad, 't': t, 'w': w}
-            return loss, aux
-        else:
-            return loss
-
-
-
-if __name__ == '__main__':
-    import argparse
-    import matplotlib.pyplot as plt
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('prompt', type=str)
-    parser.add_argument('--negative', default='', type=str)
-    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
-    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
-    parser.add_argument('-H', type=int, default=512)
-    parser.add_argument('-W', type=int, default=512)
-    parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--steps', type=int, default=50)
-    opt = parser.parse_args()
-
-    seed_everything(opt.seed)
-
-    device = torch.device('cuda')
-
-    sd = StableDiffusion_VSD(device, opt.sd_version, opt.hf_key)
-
-    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
-
-    # visualize image
-    plt.imshow(imgs[0])
-    plt.show()
-    plt.savefig(f'{opt.prompt}.png')
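
Note: the two deleted guidance modules differ only in what the pretrained noise prediction is compared against. A schematic summary (illustrative sketch, not code from the repo; names are chosen here for clarity):

# SDS (sd.py):  grad = w(t) * (eps_pretrained(x_t, y) - noise)
# VSD (vsd.py): grad = w(t) * (eps_pretrained(x_t, y) - eps_lora(x_t, y, c))
# where eps_lora is the LoRA-adapted UNet trained in parallel (train_lora) to model
# the distribution of the current renders, and c is the per-instance condition label.
def vsd_grad(w_t, eps_pretrained, eps_lora, loss_weight=1.0):
    """Assemble the VSD gradient the same way the deleted train_step does."""
    return loss_weight * w_t[:, None, None, None] * (eps_pretrained - eps_lora)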
video3d/model_ddp.py CHANGED
@@ -41,10 +41,6 @@ from .render import mesh
 from .render import light
 from .render import render
 
-from .diffusion.sd import StableDiffusion
-from .diffusion.vsd import StableDiffusion_VSD
-from .diffusion.sd_utils import rand_poses, rand_lights, append_text_direction
-
 EPS = 1e-7
 
 
@@ -1269,53 +1265,8 @@ class Unsup3DDDP:
 
         self.enable_sds = cfgs.get('enable_sds', False)
         self.enable_vsd = cfgs.get('enable_vsd', False)
-
-
-
-        # decide if use SDS or VSD
-        if self.enable_vsd:
-            # self.stable_diffusion = misc.LazyClass(StableDiffusion_VSD, device=self.device, torch_dtype=diffusion_torch_dtype)
-            self.stable_diffusion = StableDiffusion_VSD(device=self.device, torch_dtype=diffusion_torch_dtype)
-            self.diffusion_guidance_scale_lora = cfgs.get('diffusion_guidance_scale_lora', 1.)
-            self.diffusion_guidance_scale = cfgs.get('diffusion_guidance_scale', 7.5)
-        else:
-            self.stable_diffusion = misc.LazyClass(StableDiffusion, device=self.device, torch_dtype=diffusion_torch_dtype)
-            self.diffusion_guidance_scale = cfgs.get('diffusion_guidance_scale', 100.)
-
-        self.diffusion_loss_weight = cfgs.get('diffusion_loss_weight', 1.)
-        self.diffusion_num_random_cameras = cfgs.get('diffusion_num_random_cameras', 1)
-
-        # For prompts
-        self.diffusion_prompt = cfgs.get('diffusion_prompt', '')
-        self.diffusion_negative_prompt = cfgs.get('diffusion_negative_prompt', '')
-
-        # For image sampling
-        self.diffusion_albedo_ratio = cfgs.get('diffusion_albedo_ratio', 0.2)
-        self.diffusion_shading_ratio = cfgs.get('diffusion_shading_ratio', 0.4)
-        self.diffusion_light_ambient = cfgs.get('diffusion_light_ambient', 0.5)
-        self.diffusion_light_diffuse = cfgs.get('diffusion_light_diffuse', 0.8)
-        self.diffusion_radius_range = cfgs.get('diffusion_radius_range', [0.8, 1.4])
-        self.diffusion_uniform_sphere_rate = cfgs.get('diffusion_uniform_sphere_rate', 0.5)
-        self.diffusion_theta_range = cfgs.get('diffusion_theta_range', [0, 120])
-        self.diffusion_phi_offset = cfgs.get('diffusion_phi_offset', 180)
-        self.diffusion_resolution = cfgs.get('diffusion_resolution', 256)
-
-        print('-----------------------------------------------')
-        print(f"!!!!!! the phi offset for diffusion is set as {self.diffusion_phi_offset}!!!!!!!!!!!!!")
-        print('-----------------------------------------------')
-
-        # For randomizing light
-        self.diffusion_random_light = cfgs.get('diffusion_random_light', False)
-        self.diffusion_light_ambient = cfgs.get('diffusion_light_ambient', 0.5)
-        self.diffusion_light_diffuse = cfgs.get('diffusion_light_diffuse', 0.8)
-
-        # For noise scheduling
-        self.diffusion_max_step = cfgs.get('diffusion_max_step', 0.98)
-
-        # For view-dependent prompting
-        self.diffusion_append_prompt_directions = cfgs.get('diffusion_append_prompt_directions', False)
-        self.diffusion_angle_overhead = cfgs.get('diffusion_angle_overhead', 30)
-        self.diffusion_angle_front = cfgs.get('diffusion_angle_front', 60)
+        self.enable_sds = False
+        self.enable_vsd = False
 
     @staticmethod
     def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False):
@@ -2017,141 +1968,6 @@ class Unsup3DDDP:
 
         return losses, aux
 
-    def score_distillation_sampling(self, shape, texture, resolution, im_features, light, prior_shape, random_light=False, prompts=None, classes_vectors=None, im_features_map=None, w2c_pred=None):
-        num_instances = im_features.shape[0]
-        n_total_random_cameras = num_instances * self.diffusion_num_random_cameras
-
-        poses, dirs = rand_poses(
-            n_total_random_cameras, self.device, radius_range=self.diffusion_radius_range, uniform_sphere_rate=self.diffusion_uniform_sphere_rate,
-            cam_z_offset=self.cam_pos_z_offset, theta_range=self.diffusion_theta_range, phi_offset=self.diffusion_phi_offset, return_dirs=True,
-            angle_front=self.diffusion_angle_front, angle_overhead=self.diffusion_angle_overhead,
-        )
-        mvp, w2c, campos = self.netInstance.get_camera_extrinsics_from_pose(poses, crop_fov_approx=self.crop_fov_approx)
-
-        if random_light:
-            lights = rand_lights(campos, fixed_ambient=self.diffusion_light_ambient, fixed_diffuse=self.diffusion_light_diffuse)
-        else:
-            lights = light
-
-        proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(num_instances, 1, 1).to(self.device)
-        original_mvp = torch.bmm(proj, w2c_pred)
-
-        im_features = im_features.repeat(self.diffusion_num_random_cameras, 1) if im_features is not None else None
-        num_shapes = shape.v_pos.shape[0]
-        assert n_total_random_cameras % num_shapes == 0
-        shape = shape.extend(n_total_random_cameras // num_shapes)
-
-        bg_color = torch.rand((n_total_random_cameras, 3), device=self.device)  # channel-wise random
-        background = repeat(bg_color, 'b c -> b h w c', h=resolution[0], w=resolution[1])
-
-        # only train the texture
-        safe_detach = lambda x: x.detach() if x is not None else None
-        shape = safe_detach(shape)
-        im_features = safe_detach(im_features)
-        im_features_map = safe_detach(im_features_map)
-
-        set_requires_grad(texture, True)
-        set_requires_grad(light, True)
-
-        image_pred, mask_pred, _, _, albedo, shading = self.render(
-            shape,
-            texture,
-            mvp,
-            w2c,
-            campos,
-            resolution,
-            im_features=im_features,
-            light=lights,
-            prior_shape=prior_shape,
-            dino_pred=None,
-            spp=self.renderer_spp,
-            bg_image=background,
-            im_features_map={"original_mvp": original_mvp, "im_features_map": im_features_map} if im_features_map is not None else None
-        )
-        if self.enable_vsd:
-            if prompts is None:
-                prompts = n_total_random_cameras * [self.diffusion_prompt]
-            else:
-                if '_' in prompts:
-                    prompts = prompts.replace('_', ' ')
-                prompts = n_total_random_cameras * [prompts]
-
-            prompts = ['a high-resolution DSLR image of ' + x for x in prompts]
-            assert self.diffusion_append_prompt_directions
-            # TODO: check if this implementation is aligned with stable-diffusion-prompt-processor
-            prompts_vd = append_text_direction(prompts, dirs)
-            negative_prompts = n_total_random_cameras * [self.diffusion_negative_prompt]
-
-            text_embeddings = self.stable_diffusion.get_text_embeds(prompts, negative_prompts)  # [BB, 77, 768]
-            text_embeddings_vd = self.stable_diffusion.get_text_embeds(prompts_vd, negative_prompts)
-
-            camera_condition_type = 'c2w'
-            if camera_condition_type == 'c2w':
-                camera_condition = torch.linalg.inv(w2c).detach()
-            elif camera_condition_type == 'mvp':
-                camera_condition = mvp.detach()
-            else:
-                raise NotImplementedError
-
-            # Alternate among albedo, shading, and image
-            rand = torch.rand(n_total_random_cameras, device=self.device)
-            rendered_component = torch.zeros_like(image_pred)
-            mask_pred = mask_pred[:, None]
-            background = rearrange(background, 'b h w c -> b c h w')
-            albedo_flag = rand > (1 - self.diffusion_albedo_ratio)
-            rendered_component[albedo_flag] = albedo[albedo_flag] * mask_pred[albedo_flag] + (1 - mask_pred[albedo_flag]) * background[albedo_flag]
-            shading_flag = (rand > (1 - self.diffusion_albedo_ratio - self.diffusion_shading_ratio)) & (rand <= (1 - self.diffusion_albedo_ratio))
-            rendered_component[shading_flag] = shading.repeat(1, 3, 1, 1)[shading_flag] / 2 * mask_pred[shading_flag] + (1 - mask_pred[shading_flag]) * background[shading_flag]
-            rendered_component[~(albedo_flag | shading_flag)] = image_pred[~(albedo_flag | shading_flag)]
-
-            condition_label = classes_vectors
-            # condition_label = im_features
-
-            sd_loss, sd_aux = self.stable_diffusion.train_step(
-                text_embeddings,
-                text_embeddings_vd,
-                rendered_component,
-                camera_condition,  # TODO: can we input category condition in lora?
-                condition_label,
-                guidance_scale=self.diffusion_guidance_scale,
-                guidance_scale_lora=self.diffusion_guidance_scale_lora,
-                loss_weight=self.diffusion_loss_weight,
-                max_step_pct=self.diffusion_max_step,
-                return_aux=True
-            )
-
-            aux = {'loss': sd_loss['loss_vsd'], 'loss_lora': sd_loss['loss_lora'], 'dirs': dirs, 'sd_aux': sd_aux, 'rendered_shape': shape}
-
-        else:
-            # Prompt to text embeds
-            if prompts is None:
-                prompts = n_total_random_cameras * [self.diffusion_prompt]
-            else:
-                if '_' in prompts:
-                    prompts = prompts.replace('_', ' ')
-                prompts = n_total_random_cameras * [prompts]
-            prompts = ['a high-resolution DSLR image of ' + x for x in prompts]
-            if self.diffusion_append_prompt_directions:
-                prompts = append_text_direction(prompts, dirs)
-            negative_prompts = n_total_random_cameras * [self.diffusion_negative_prompt]
-            text_embeddings = self.stable_diffusion.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]
-
-            # Alternate among albedo, shading, and image
-            rand = torch.rand(n_total_random_cameras, device=self.device)
-            rendered_component = torch.zeros_like(image_pred)
-            mask_pred = mask_pred[:, None]
-            background = rearrange(background, 'b h w c -> b c h w')
-            albedo_flag = rand > (1 - self.diffusion_albedo_ratio)
-            rendered_component[albedo_flag] = albedo[albedo_flag] * mask_pred[albedo_flag] + (1 - mask_pred[albedo_flag]) * background[albedo_flag]
-            shading_flag = (rand > (1 - self.diffusion_albedo_ratio - self.diffusion_shading_ratio)) & (rand <= (1 - self.diffusion_albedo_ratio))
-            rendered_component[shading_flag] = shading.repeat(1, 3, 1, 1)[shading_flag] / 2 * mask_pred[shading_flag] + (1 - mask_pred[shading_flag]) * background[shading_flag]
-            rendered_component[~(albedo_flag | shading_flag)] = image_pred[~(albedo_flag | shading_flag)]
-            sd_loss, sd_aux = self.stable_diffusion.train_step(
-                text_embeddings, rendered_component, guidance_scale=self.diffusion_guidance_scale, loss_weight=self.diffusion_loss_weight, max_step_pct=self.diffusion_max_step, return_aux=True)
-            aux = {'loss': sd_loss, 'dirs': dirs, 'sd_aux': sd_aux, 'rendered_shape': shape}
-
-        return rendered_component, aux
-
     def parse_dict_definition(self, dict_config, total_iter):
         '''
         The dict_config is a diction-based configuration with ascending order
@@ -2987,19 +2803,6 @@ class Unsup3DDDP:
             final_losses[name] = loss.mean()
         final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean()
 
-        ## score distillation sampling
-        sds_random_images = None
-        if self.enable_sds:
-            prompts = None
-            if classes_vectors is not None:
-                prompts = category_name[0]
-            sds_random_images, sds_aux = self.score_distillation_sampling(shape, texture, [self.diffusion_resolution, self.diffusion_resolution], im_features, light, prior_shape, prompts=prompts, classes_vectors=class_vector[None, :].expand(batch_size * num_frames, -1), im_features_map=im_features_map, w2c_pred=w2c)
-            if self.enable_vsd:
-                final_losses.update({'vsd_loss': sds_aux['loss']})
-                final_losses.update({'vsd_lora_loss': sds_aux['loss_lora']})
-            else:
-                final_losses.update({'sds_loss': sds_aux['loss']})
-
         ## mask distribution loss
         mask_distribution_aux = None
         if self.enable_mask_distribution:
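
Note: with this commit the SDS/VSD branch is disabled outright; the two flags read from cfgs are immediately overridden to False, and the guidance setup that depended on them, including the misc.LazyClass deferred wrapper around StableDiffusion, is deleted. For readers unfamiliar with that pattern, a generic lazy-construction wrapper looks roughly like the following. This is a hypothetical sketch, not the project's actual misc.LazyClass:

# Hypothetical sketch of a lazy-construction wrapper in the spirit of misc.LazyClass
# (the real helper is not part of this diff; names and behavior here are assumptions).
class LazyClass:
    def __init__(self, cls, **kwargs):
        self._cls = cls
        self._kwargs = kwargs
        self._instance = None

    def _get(self):
        if self._instance is None:
            self._instance = self._cls(**self._kwargs)  # build on first access
        return self._instance

    def __getattr__(self, name):
        # forward attribute access to the lazily built instance
        return getattr(self._get(), name)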