OutofFocus / app.py
fffiloni's picture
Update app.py
538194b verified
import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import gc
import gradio as gr
import numpy as np
import os
import pickle
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
import argparse
weights = {
'down': {
4096: 0.0,
1024: 1.0,
256: 1.0,
},
'mid': {
64: 1.0,
},
'up': {
256: 1.0,
1024: 1.0,
4096: 0.0,
}
}
num_inference_steps = 10
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
should_stop = False
def save_state_to_file(state):
filename = "state.pkl"
with open(filename, 'wb') as f:
pickle.dump(state, f)
return filename
def load_state_from_file(filename):
with open(filename, 'rb') as f:
state = pickle.load(f)
return state
def stop_reconstruct():
global should_stop
should_stop = True
def reconstruct(input_img, caption):
img = input_img
cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((512, 512)),
torchvision.transforms.ToTensor()
])
loaded_image = transform(img).to("cuda").unsqueeze(0)
if loaded_image.shape[1] == 4:
loaded_image = loaded_image[:,:3,:,:]
with torch.no_grad():
encoded_image = pipe.vae.encode(loaded_image*2 - 1)
real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
guidance_scale = 1
inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = inverse_scheduler.timesteps
latents = real_image_latents
inversed_latents = []
with torch.no_grad():
replace_attention_processor(pipe.unet, True)
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
inversed_latents.append(latents)
latent_model_input = torch.cat([latents] * 2)
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds_combined,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# initial state
real_image_initial_latents = latents
W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
QT = nn.Parameter(W_values.clone())
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps
optimizer = torch.optim.AdamW([QT], lr=0.008)
pipe.vae.eval()
pipe.vae.requires_grad_(False)
pipe.unet.eval()
pipe.unet.requires_grad_(False)
last_loss = 1
for epoch in range(50):
gc.collect()
torch.cuda.empty_cache()
if last_loss < 0.02:
break
elif last_loss < 0.03:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.003
elif last_loss < 0.035:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.006
intermediate_values = real_image_initial_latents.clone()
for i in range(num_inference_steps):
latents = intermediate_values.detach().clone()
t = timesteps[i]
prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])
latent_model_input = torch.cat([latents] * 2)
noise_pred_model = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")
last_loss = loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
global should_stop
if should_stop:
should_stop = False
break
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
image_np = (image_np * 255).astype(np.uint8)
yield image_np, caption, [caption, real_image_initial_latents, QT]
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
image_np = (image_np * 255).astype(np.uint8)
yield image_np, caption, [caption, real_image_initial_latents, QT]
class AttnReplaceProcessor(AttnProcessor2_0):
def __init__(self, replace_all, weight):
super().__init__()
self.replace_all = replace_all
self.weight = weight
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
is_cross = not encoder_hidden_states is None
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, _, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))
dimension_squared = hidden_states.shape[1]
if not is_cross and (self.replace_all):
ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
attn_scores_dst.copy_(self.weight[dimension_squared] * attn_scores_src + (1.0 - self.weight[dimension_squared]) * attn_scores_dst)
ucond_attn_scores_dst.copy_(self.weight[dimension_squared] * ucond_attn_scores_src + (1.0 - self.weight[dimension_squared]) * ucond_attn_scores_dst)
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
del attention_probs
hidden_states = attn.to_out[0](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def replace_attention_processor(unet, clear = False):
for name, module in unet.named_modules():
if 'attn1' in name and 'to' not in name:
layer_type = name.split('.')[0].split('_')[0]
if not clear:
if layer_type == 'down':
module.processor = AttnReplaceProcessor(True, weights['down'])
elif layer_type == 'mid':
module.processor = AttnReplaceProcessor(True, weights['mid'])
elif layer_type == 'up':
module.processor = AttnReplaceProcessor(True, weights['up'])
else:
module.processor = AttnReplaceProcessor(False, 0.0)
def apply_prompt(meta_data, new_prompt):
caption, real_image_initial_latents, QT = meta_data
inference_steps = len(QT)
cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
# uncond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
guidance_scale = 7.5
scheduler.set_timesteps(inference_steps, device="cuda")
timesteps = scheduler.timesteps
latents = torch.cat([real_image_initial_latents] * 2)
with torch.no_grad():
replace_attention_processor(pipe.unet)
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
latent_model_input = torch.cat([latents] * 2)
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=modified_prompt_embeds,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
replace_attention_processor(pipe.unet, True)
image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
image_np = (image_np * 255).astype(np.uint8)
return image_np
def on_image_change(filepath):
# Extract the filename without extension
filename = os.path.splitext(os.path.basename(filepath))[0]
# Check if the filename is "example1" or "example2"
if filename in ["example1", "example2", "example3", "example4"]:
meta_data_raw = load_state_from_file(f"assets/{filename}.pkl")
_, _, QT_raw = meta_data_raw
global num_inference_steps
num_inference_steps = len(QT_raw)
scale_value = 7
new_prompt = ""
if filename == "example1":
scale_value = 7
new_prompt = "a photo of a tree, summer, colourful"
elif filename == "example2":
scale_value = 8
new_prompt = "a photo of a panda, two ears, white background"
elif filename == "example3":
scale_value = 7
new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"
elif filename == "example4":
scale_value = 7
new_prompt = "a photo of plastic bottle on some sand, beach background, sky background"
update_scale(scale_value)
img = apply_prompt(meta_data_raw, new_prompt)
return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value
def update_value(value, key, res):
global weights
weights[key][res] = value
def update_step(value):
global num_inference_steps
num_inference_steps = value
def update_scale(scale):
values = [1.0] * 7
if scale == 9:
return values
reduction_steps = (9 - scale) * 0.5
for i in range(4): # There are 4 positions to reduce symmetrically
if reduction_steps >= 1:
values[i] = 0.0
values[-(i + 1)] = 0.0
reduction_steps -= 1
elif reduction_steps > 0:
values[i] = 0.5
values[-(i + 1)] = 0.5
break
global weights
index = 0
for outer_key, inner_dict in weights.items():
for inner_key in inner_dict:
inner_dict[inner_key] = values[index]
index += 1
return weights['down'][4096], weights['down'][1024], weights['down'][256], weights['mid'][64], weights['up'][256], weights['up'][1024], weights['up'][4096]
with gr.Blocks() as demo:
gr.Markdown(
'''
<div style="text-align: center;">
<div style="display: flex; justify-content: center;">
<img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
</div>
<h1>Out of Focus 1.0</h1>
<p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
<a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Ashleigh%20Watson"></a> &ensp;
<a href="https://twitter.com/banterless_ai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Alex%20Nasa"></a>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<p style="display: flex;gap: 6px;">
<a href="https://huggingface.co/spaces/fffiloni/OutofFocus?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
</a> to skip the queue and enjoy faster inference on the GPU of your choice
</p>
</div>
'''
)
with gr.Row():
with gr.Column():
with gr.Row():
example_input = gr.Image(height=512, width=512, type="filepath", visible=False)
image_input = gr.Image(height=512, width=512, type="pil", label="Upload Source Image")
steps_slider = gr.Slider(minimum=5, maximum=25, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
reconstruct_button = gr.Button("Reconstruct")
stop_button = gr.Button("Stop", variant="stop", interactive=False)
with gr.Column():
reconstructed_image = gr.Image(type="pil", label="Reconstructed")
with gr.Row():
invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
interpolate_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
with gr.Row():
new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or word addition at the end, achieve the best results by swapping words instead of adding or removing in between")
with gr.Row():
apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
with gr.Row():
with gr.Accordion(label="Advanced Options", open=False):
gr.Markdown(
'''
<div style="text-align: center;">
<h1>Weight Adjustment</h1>
<p style="font-size:16px;">Specific Cross-Attention Influence weights can be manually modified for given resolutions (1.0 = Fully Source Attn 0.0 = Fully Target Attn)</p>
</div>
'''
)
down_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][4096], label="Self-Attn Down 64x64")
down_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][1024], label="Self-Attn Down 32x32")
down_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][256], label="Self-Attn Down 16x16")
mid_slider_64 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['mid'][64], label="Self-Attn Mid 8x8")
up_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][256], label="Self-Attn Up 16x16")
up_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][1024], label="Self-Attn Up 32x32")
up_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][4096], label="Self-Attn Up 64x64")
with gr.Row():
show_case = gr.Examples(
examples=[
["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background"],
["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful"],
["assets/example2.png", "a photo of a cat, two ears, white background", "a photo of a panda, two ears, white background"],
["assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"],
],
inputs=[example_input, prompt_input, new_prompt_input],
label=None
)
meta_data = gr.State()
example_input.change(
fn=on_image_change,
inputs=example_input,
outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]
).then(
lambda: gr.update(interactive=True),
outputs=apply_button
).then(
lambda: gr.update(interactive=True),
outputs=new_prompt_input
)
steps_slider.release(update_step, inputs=steps_slider)
interpolate_slider.release(update_scale, inputs=interpolate_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
invisible_slider.change(update_scale, inputs=invisible_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
up_slider_4096.change(update_value, inputs=[up_slider_4096, gr.State('up'), gr.State(4096)])
up_slider_1024.change(update_value, inputs=[up_slider_1024, gr.State('up'), gr.State(1024)])
up_slider_256.change(update_value, inputs=[up_slider_256, gr.State('up'), gr.State(256)])
down_slider_4096.change(update_value, inputs=[down_slider_4096, gr.State('down'), gr.State(4096)])
down_slider_1024.change(update_value, inputs=[down_slider_1024, gr.State('down'), gr.State(1024)])
down_slider_256.change(update_value, inputs=[down_slider_256, gr.State('down'), gr.State(256)])
mid_slider_64.change(update_value, inputs=[mid_slider_64, gr.State('mid'), gr.State(64)])
reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, meta_data]).then(
lambda: gr.update(interactive=True),
outputs=reconstruct_button
).then(
lambda: gr.update(interactive=True),
outputs=new_prompt_input
).then(
lambda: gr.update(interactive=True),
outputs=apply_button
).then(
lambda: gr.update(interactive=False),
outputs=stop_button
)
reconstruct_button.click(
lambda: gr.update(interactive=False),
outputs=reconstruct_button
)
reconstruct_button.click(
lambda: gr.update(interactive=True),
outputs=stop_button
)
reconstruct_button.click(
lambda: gr.update(interactive=False),
outputs=apply_button
)
stop_button.click(
lambda: gr.update(interactive=False),
outputs=stop_button
)
apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
stop_button.click(stop_reconstruct)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
demo.queue()
demo.launch(share=args.share)