Spaces:
Running
on
A10G
Running
on
A10G
import torch | |
from pytorch_lightning import seed_everything | |
from tqdm import tqdm | |
from lib.utils.iimage import IImage | |
from lib.smplfusion import share, router, attentionpatch, transformerpatch | |
from lib.smplfusion.patches.attentionpatch import painta | |
from lib.utils import tokenize | |
verbose = False | |
def init_painta(token_idx): | |
# Initialize painta | |
router.attention_forward = attentionpatch.painta.forward | |
router.basic_transformer_forward = transformerpatch.painta.forward | |
painta.painta_on = True | |
painta.painta_res = [16, 32] | |
painta.token_idx = token_idx | |
def run( | |
ddim, | |
method, | |
prompt, | |
image, | |
mask, | |
seed, | |
eta, | |
prefix, | |
negative_prompt, | |
positive_prompt, | |
dt, | |
guidance_scale | |
): | |
# Text condition | |
context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt]) | |
token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index('<end_of_text>'))) | |
token_idx += [tokenize(prompt + positive_prompt).index('<end_of_text>')] | |
# Setup painta if needed | |
if 'painta' in method: init_painta(token_idx) | |
else: router.reset() | |
# Image condition | |
unet_condition = ddim.get_inpainting_condition(image, mask) | |
dtype = unet_condition.dtype | |
share.set_mask(mask) | |
# Starting latent | |
seed_everything(seed) | |
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda().to(dtype) | |
# Turn off gradients | |
ddim.unet.requires_grad_(False) | |
pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt) | |
for timestep in share.DDIMIterator(pbar): | |
if share.timestep <= 500: router.reset() | |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) | |
with torch.autocast('cuda'): | |
eps_uncond, eps = ddim.unet( | |
torch.cat([_zt, _zt]).to(dtype), | |
timesteps = torch.tensor([timestep, timestep]).cuda(), | |
context = context | |
).chunk(2) | |
eps = (eps_uncond + guidance_scale * (eps - eps_uncond)) | |
z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep] | |
zt = share.schedule.sqrt_alphas[timestep - dt] * z0 + share.schedule.sqrt_one_minus_alphas[timestep - dt] * eps | |
with torch.no_grad(): | |
output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor)) | |
return output_image | |