MasaCtrl / gradio_app /image_synthesis_app.py
Ljzycmd's picture
Update gradio_app/image_synthesis_app.py
47d853f verified
import gradio as gr
import numpy as np
import torch
from diffusers import DDIMScheduler
from pytorch_lightning import seed_everything
from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import (AttentionBase,
regiter_attention_editor_diffusers)
from .app_utils import global_context
torch.set_grad_enabled(False)
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
# "cpu")
# model_path = "andite/anything-v4.0"
# scheduler = DDIMScheduler(beta_start=0.00085,
# beta_end=0.012,
# beta_schedule="scaled_linear",
# clip_sample=False,
# set_alpha_to_one=False)
# model = MasaCtrlPipeline.from_pretrained(model_path,
# scheduler=scheduler).to(device)
def consistent_synthesis(source_prompt, target_prompt, starting_step,
starting_layer, image_resolution, ddim_steps, scale,
seed, appended_prompt, negative_prompt):
from masactrl.masactrl import MutualSelfAttentionControl
model = global_context["model"]
device = global_context["device"]
seed_everything(seed)
with torch.no_grad():
if appended_prompt is not None:
source_prompt += appended_prompt
target_prompt += appended_prompt
prompts = [source_prompt, target_prompt]
# initialize the noise map
start_code = torch.randn([1, 4, 64, 64], device=device)
start_code = start_code.expand(len(prompts), -1, -1, -1)
# inference the synthesized image without MasaCtrl
editor = AttentionBase()
regiter_attention_editor_diffusers(model, editor)
target_image_ori = model([target_prompt],
latents=start_code[-1:],
guidance_scale=7.5)
target_image_ori = target_image_ori.cpu().permute(0, 2, 3, 1).numpy()
# inference the synthesized image with MasaCtrl
# hijack the attention module
controller = MutualSelfAttentionControl(starting_step, starting_layer)
regiter_attention_editor_diffusers(model, controller)
# inference the synthesized image
image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)
image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
return [image_masactrl[0], target_image_ori[0],
image_masactrl[1]] # source, fixed seed, masactrl
def create_demo_synthesis():
with gr.Blocks() as demo:
gr.Markdown("## **Input Settings**")
with gr.Row():
with gr.Column():
source_prompt = gr.Textbox(
label="Source Prompt",
value='1boy, casual, outdoors, sitting',
interactive=True)
target_prompt = gr.Textbox(
label="Target Prompt",
value='1boy, casual, outdoors, standing',
interactive=True)
with gr.Row():
ddim_steps = gr.Slider(label="DDIM Steps",
minimum=1,
maximum=999,
value=50,
step=1)
starting_step = gr.Slider(
label="Step of MasaCtrl",
minimum=0,
maximum=999,
value=4,
step=1)
starting_layer = gr.Slider(label="Layer of MasaCtrl",
minimum=0,
maximum=16,
value=10,
step=1)
run_btn = gr.Button("Run")
with gr.Column():
appended_prompt = gr.Textbox(label="Appended Prompt", value='')
negative_prompt = gr.Textbox(label="Negative Prompt", value='')
with gr.Row():
image_resolution = gr.Slider(label="Image Resolution",
minimum=256,
maximum=768,
value=512,
step=64)
scale = gr.Slider(label="CFG Scale",
minimum=0.1,
maximum=30.0,
value=7.5,
step=0.1)
seed = gr.Slider(label="Seed",
minimum=-1,
maximum=2147483647,
value=42,
step=1)
gr.Markdown("## **Output**")
with gr.Row():
image_source = gr.Image(label="Source Image")
image_fixed = gr.Image(label="Image with Fixed Seed")
image_masactrl = gr.Image(label="Image with MasaCtrl")
inputs = [
source_prompt, target_prompt, starting_step, starting_layer,
image_resolution, ddim_steps, scale, seed, appended_prompt,
negative_prompt
]
run_btn.click(consistent_synthesis, inputs,
[image_source, image_fixed, image_masactrl])
gr.Examples(
[[
"1boy, bishounen, casual, indoors, sitting, coffee shop, bokeh",
"1boy, bishounen, casual, indoors, standing, coffee shop, bokeh",
42
],
[
"1boy, casual, outdoors, sitting",
"1boy, casual, outdoors, sitting, side view", 42
],
[
"1boy, casual, outdoors, sitting",
"1boy, casual, outdoors, standing, clapping hands", 42
],
[
"1boy, casual, outdoors, sitting",
"1boy, casual, outdoors, sitting, shows thumbs up", 42
],
[
"1boy, casual, outdoors, sitting",
"1boy, casual, outdoors, sitting, with crossed arms", 42
],
[
"1boy, casual, outdoors, sitting",
"1boy, casual, outdoors, sitting, rasing hands", 42
]],
[source_prompt, target_prompt, seed],
)
return demo
if __name__ == "__main__":
demo_syntehsis = create_demo_synthesis()
demo_synthesis.launch()