import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import gc
import gradio as gr
import numpy as np
import os
import pickle
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
import argparse
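# Per-resolution blending weights for the shared self-attention trick used during
# editing. Keys are self-attention sequence lengths (64x64 = 4096, 32x32 = 1024,
# 16x16 = 256, 8x8 = 64 latent tokens); 1.0 keeps the source image's attention,
# 0.0 keeps the target prompt's attention.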
weights = {
'down': {
4096: 0.0,
1024: 1.0,
256: 1.0,
},
'mid': {
64: 1.0,
},
'up': {
256: 1.0,
1024: 1.0,
4096: 0.0,
}
}
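# Global defaults: DDIM step count (updated by the "Steps" slider) and the
# Stable Diffusion 2.1 base pipeline, its forward/inverse DDIM schedulers,
# and a safety checker applied to every decoded image.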
num_inference_steps = 10
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
should_stop = False
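# Helpers to pickle/unpickle the editing state (caption, initial latents, QT)
# so the bundled examples can be reloaded without rerunning reconstruction.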
def save_state_to_file(state):
filename = "state.pkl"
with open(filename, 'wb') as f:
pickle.dump(state, f)
return filename
def load_state_from_file(filename):
with open(filename, 'rb') as f:
state = pickle.load(f)
return state
def stop_reconstruct():
global should_stop
should_stop = True
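# Phase 1: DDIM-invert the input image under its caption to obtain a latent
# trajectory. Phase 2: optimise a learnable unconditional embedding QT for each
# timestep (similar in spirit to null-text inversion) so that resampling with
# classifier-free guidance reproduces the inverted trajectory. Yields a preview
# image plus the editing state (caption, initial latents, QT) after every epoch.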
def reconstruct(input_img, caption):
img = input_img
cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((512, 512)),
torchvision.transforms.ToTensor()
])
loaded_image = transform(img).to("cuda").unsqueeze(0)
if loaded_image.shape[1] == 4:
loaded_image = loaded_image[:,:3,:,:]
with torch.no_grad():
encoded_image = pipe.vae.encode(loaded_image*2 - 1)
real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
guidance_scale = 1
inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = inverse_scheduler.timesteps
latents = real_image_latents
inversed_latents = []
with torch.no_grad():
replace_attention_processor(pipe.unet, True)
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
inversed_latents.append(latents)
latent_model_input = torch.cat([latents] * 2)
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds_combined,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# initial state
real_image_initial_latents = latents
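# Learnable per-step unconditional embeddings, initialised from the empty-prompt
# embedding (one copy per inference step); these are the parameters optimised below.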
W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
QT = nn.Parameter(W_values.clone())
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps
optimizer = torch.optim.AdamW([QT], lr=0.008)
pipe.vae.eval()
pipe.vae.requires_grad_(False)
pipe.unet.eval()
pipe.unet.requires_grad_(False)
last_loss = 1
for epoch in range(50):
gc.collect()
torch.cuda.empty_cache()
if last_loss < 0.02:
break
elif last_loss < 0.03:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.003
elif last_loss < 0.035:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.006
intermediate_values = real_image_initial_latents.clone()
for i in range(num_inference_steps):
latents = intermediate_values.detach().clone()
t = timesteps[i]
prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])
latent_model_input = torch.cat([latents] * 2)
noise_pred_model = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")
last_loss = loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
global should_stop
if should_stop:
should_stop = False
break
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
image_np = (image_np * 255).astype(np.uint8)
yield image_np, caption, [caption, real_image_initial_latents, QT]
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
image_np = (image_np * 255).astype(np.uint8)
yield image_np, caption, [caption, real_image_initial_latents, QT]
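# Attention processor that, for self-attention layers only, blends the attention
# scores of the source branch into the target branch of a 4-way batch
# (uncond-src, uncond-dst, cond-src, cond-dst), using the per-resolution weight
# selected by the current sequence length.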
class AttnReplaceProcessor(AttnProcessor2_0):
def __init__(self, replace_all, weight):
super().__init__()
self.replace_all = replace_all
self.weight = weight
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
is_cross = encoder_hidden_states is not None
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, _, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))
dimension_squared = hidden_states.shape[1]
if not is_cross and (self.replace_all):
ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
attn_scores_dst.copy_(self.weight[dimension_squared] * attn_scores_src + (1.0 - self.weight[dimension_squared]) * attn_scores_dst)
ucond_attn_scores_dst.copy_(self.weight[dimension_squared] * ucond_attn_scores_src + (1.0 - self.weight[dimension_squared]) * ucond_attn_scores_dst)
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
del attention_probs
hidden_states = attn.to_out[0](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
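# Install AttnReplaceProcessor on every self-attention ("attn1") module of the
# UNet, choosing the weight table by block type (down/mid/up). With clear=True
# the processors are reset to a pass-through that does no score blending.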
def replace_attention_processor(unet, clear = False):
for name, module in unet.named_modules():
if 'attn1' in name and 'to' not in name:
layer_type = name.split('.')[0].split('_')[0]
if not clear:
if layer_type == 'down':
module.processor = AttnReplaceProcessor(True, weights['down'])
elif layer_type == 'mid':
module.processor = AttnReplaceProcessor(True, weights['mid'])
elif layer_type == 'up':
module.processor = AttnReplaceProcessor(True, weights['up'])
else:
module.processor = AttnReplaceProcessor(False, 0.0)
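# Re-run sampling from the stored initial latents with a new prompt, injecting
# the source image's self-attention according to the weight table, and decode
# only the edited (target) branch of the batch.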
def apply_prompt(meta_data, new_prompt):
caption, real_image_initial_latents, QT = meta_data
inference_steps = len(QT)
cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
# uncond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
guidance_scale = 7.5
scheduler.set_timesteps(inference_steps, device="cuda")
timesteps = scheduler.timesteps
latents = torch.cat([real_image_initial_latents] * 2)
with torch.no_grad():
replace_attention_processor(pipe.unet)
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
latent_model_input = torch.cat([latents] * 2)
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=modified_prompt_embeds,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
replace_attention_processor(pipe.unet, True)
image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
image_np = (image_np * 255).astype(np.uint8)
return image_np
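# When one of the bundled example images is selected, load its precomputed
# editing state from assets/, set the matching prompt and influence scale,
# and immediately apply the edit.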
def on_image_change(filepath):
# Extract the filename without extension
filename = os.path.splitext(os.path.basename(filepath))[0]
# Check if the filename matches one of the bundled examples
if filename in ["example1", "example2", "example3", "example4"]:
meta_data_raw = load_state_from_file(f"assets/{filename}.pkl")
_, _, QT_raw = meta_data_raw
global num_inference_steps
num_inference_steps = len(QT_raw)
scale_value = 7
new_prompt = ""
if filename == "example1":
scale_value = 7
new_prompt = "a photo of a tree, summer, colourful"
elif filename == "example2":
scale_value = 8
new_prompt = "a photo of a panda, two ears, white background"
elif filename == "example3":
scale_value = 7
new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"
elif filename == "example4":
scale_value = 7
new_prompt = "a photo of plastic bottle on some sand, beach background, sky background"
update_scale(scale_value)
img = apply_prompt(meta_data_raw, new_prompt)
return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value
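# Gradio callbacks that keep the module-level weight table and step count in
# sync with the advanced controls.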
def update_value(value, key, res):
global weights
weights[key][res] = value
def update_step(value):
global num_inference_steps
num_inference_steps = value
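# Map the 0-9 "Cross-Attention Influence" scale to the seven per-resolution
# weights, zeroing (or halving) them symmetrically from the outermost
# resolutions inwards as the scale decreases, and return the new values for
# the advanced-options number boxes.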
def update_scale(scale):
values = [1.0] * 7
reduction_steps = (9 - scale) * 0.5
for i in range(4): # There are 4 positions to reduce symmetrically
if reduction_steps >= 1:
values[i] = 0.0
values[-(i + 1)] = 0.0
reduction_steps -= 1
elif reduction_steps > 0:
values[i] = 0.5
values[-(i + 1)] = 0.5
break
global weights
index = 0
for outer_key, inner_dict in weights.items():
for inner_key in inner_dict:
inner_dict[inner_key] = values[index]
index += 1
return weights['down'][4096], weights['down'][1024], weights['down'][256], weights['mid'][64], weights['up'][256], weights['up'][1024], weights['up'][4096]
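# Gradio UI: source image and prompt on the left, reconstruction/editing on the
# right, with advanced per-resolution weight controls and bundled examples.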
with gr.Blocks() as demo:
gr.Markdown(
'''
<div style="text-align: center;">
<div style="display: flex; justify-content: center;">
<img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
</div>
<h1>Out of Focus 1.0</h1>
<p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is the first version of our image modification tool: it edits an image through prompt manipulation after reconstructing it via the diffusion inversion process.</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
<a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Ashleigh%20Watson"></a> &ensp;
<a href="https://twitter.com/banterless_ai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Alex%20Nasa"></a>
</div>
'''
)
with gr.Row():
with gr.Column():
with gr.Row():
example_input = gr.Image(height=512, width=512, type="filepath", visible=False)
image_input = gr.Image(height=512, width=512, type="pil", label="Upload Source Image")
steps_slider = gr.Slider(minimum=5, maximum=25, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
prompt_input = gr.Textbox(label="Prompt", info="Give a detailed initial prompt describing the image")
reconstruct_button = gr.Button("Reconstruct")
stop_button = gr.Button("Stop", variant="stop", interactive=False)
with gr.Column():
reconstructed_image = gr.Image(type="pil", label="Reconstructed")
with gr.Row():
invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
interpolate_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, label="Cross-Attention Influence", info="Scales the relative influence the source image has on the target image")
with gr.Row():
new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or appending words at the end; you will get the best results by swapping words rather than adding or removing words in the middle")
with gr.Row():
apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
with gr.Row():
with gr.Accordion(label="Advanced Options", open=False):
gr.Markdown(
'''
<div style="text-align: center;">
<h1>Weight Adjustment</h1>
<p style="font-size:16px;">Cross-Attention Influence weights can be manually adjusted for each resolution (1.0 = fully source attention, 0.0 = fully target attention)</p>
</div>
'''
)
down_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][4096], label="Self-Attn Down 64x64")
down_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][1024], label="Self-Attn Down 32x32")
down_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][256], label="Self-Attn Down 16x16")
mid_slider_64 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['mid'][64], label="Self-Attn Mid 8x8")
up_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][256], label="Self-Attn Up 16x16")
up_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][1024], label="Self-Attn Up 32x32")
up_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][4096], label="Self-Attn Up 64x64")
with gr.Row():
show_case = gr.Examples(
examples=[
["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background"],
["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful"],
["assets/example2.png", "a photo of a cat, two ears, white background", "a photo of a panda, two ears, white background"],
["assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"],
],
inputs=[example_input, prompt_input, new_prompt_input],
label=None
)
meta_data = gr.State()
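# Event wiring: selecting an example loads its precomputed state; the sliders
# update the module-level globals; Reconstruct streams previews and enables
# editing; Stop aborts the optimisation loop via the should_stop flag.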
example_input.change(
fn=on_image_change,
inputs=example_input,
outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]
).then(
lambda: gr.update(interactive=True),
outputs=apply_button
).then(
lambda: gr.update(interactive=True),
outputs=new_prompt_input
)
steps_slider.release(update_step, inputs=steps_slider)
interpolate_slider.release(update_scale, inputs=interpolate_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
invisible_slider.change(update_scale, inputs=invisible_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
up_slider_4096.change(update_value, inputs=[up_slider_4096, gr.State('up'), gr.State(4096)])
up_slider_1024.change(update_value, inputs=[up_slider_1024, gr.State('up'), gr.State(1024)])
up_slider_256.change(update_value, inputs=[up_slider_256, gr.State('up'), gr.State(256)])
down_slider_4096.change(update_value, inputs=[down_slider_4096, gr.State('down'), gr.State(4096)])
down_slider_1024.change(update_value, inputs=[down_slider_1024, gr.State('down'), gr.State(1024)])
down_slider_256.change(update_value, inputs=[down_slider_256, gr.State('down'), gr.State(256)])
mid_slider_64.change(update_value, inputs=[mid_slider_64, gr.State('mid'), gr.State(64)])
reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, meta_data]).then(
lambda: gr.update(interactive=True),
outputs=reconstruct_button
).then(
lambda: gr.update(interactive=True),
outputs=new_prompt_input
).then(
lambda: gr.update(interactive=True),
outputs=apply_button
).then(
lambda: gr.update(interactive=False),
outputs=stop_button
)
reconstruct_button.click(
lambda: gr.update(interactive=False),
outputs=reconstruct_button
)
reconstruct_button.click(
lambda: gr.update(interactive=True),
outputs=stop_button
)
reconstruct_button.click(
lambda: gr.update(interactive=False),
outputs=apply_button
)
stop_button.click(
lambda: gr.update(interactive=False),
outputs=stop_button
)
apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
stop_button.click(stop_reconstruct)
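# Example command-line usage (assuming this file is saved as app.py):
#   python app.py            # launch the demo locally
#   python app.py --share    # also create a public Gradio share link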
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
demo.queue()
demo.launch(share=args.share)