Spaces:
Runtime error
Runtime error
# demo inspired by https://huggingface.co/spaces/lambdalabs/image-mixer-demo | |
import argparse | |
import copy | |
import gradio as gr | |
import torch | |
from functools import partial | |
from itertools import chain | |
from torch import autocast | |
from pytorch_lightning import seed_everything | |
from basicsr.utils import tensor2img | |
from ldm.inference_base import DEFAULT_NEGATIVE_PROMPT, diffusion_inference, get_adapters, get_sd_models | |
from ldm.modules.extra_condition import api | |
from ldm.modules.extra_condition.api import ExtraCondition, get_cond_model | |
from ldm.modules.encoders.adapter import CoAdapterFuser | |
import os | |
from huggingface_hub import hf_hub_url | |
import subprocess | |
import shlex | |
# Inference-only demo: gradients are never needed.
torch.set_grad_enabled(False)

# Checkpoints to fetch at start-up, keyed by Hugging Face repo id.
urls = {
    'TencentARC/T2I-Adapter': [
        'third-party-models/body_pose_model.pth', 'third-party-models/table5_pidinet.pth',
        'models/coadapter-canny-sd15v1.pth',
        'models/coadapter-color-sd15v1.pth',
        'models/coadapter-sketch-sd15v1.pth',
        'models/coadapter-style-sd15v1.pth',
        'models/coadapter-depth-sd15v1.pth',
        'models/coadapter-fuser-sd15v1.pth',
    ],
    'runwayml/stable-diffusion-v1-5': ['v1-5-pruned-emaonly.ckpt'],
    'andite/anything-v4.0': ['anything-v4.5-pruned.ckpt', 'anything-v4.0.vae.pt'],
}

os.makedirs('models', exist_ok=True)
for repo, files in urls.items():
    for file in files:
        url = hf_hub_url(repo, file)
        name_ckp = url.split('/')[-1]
        save_path = os.path.join('models', name_ckp)
        if os.path.exists(save_path):
            continue
        # Download to a temporary name first: `wget -O` creates the output
        # file even when the transfer fails or is interrupted, and a stale
        # empty/partial file at `save_path` would make every later start-up
        # skip the download and try to load a corrupt checkpoint.
        tmp_path = save_path + '.part'
        # Argument list (no shell string) so the URL is passed through verbatim.
        completed = subprocess.run(['wget', url, '-O', tmp_path])
        if completed.returncode == 0:
            os.replace(tmp_path, save_path)
        else:
            # A failed download is reported but not fatal; a checkpoint that
            # is actually required surfaces as a missing file when loaded.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            print(f'Failed to download {url} (wget exit code {completed.returncode})')
# Condition types that have a CoAdapter checkpoint under models/.
supported_cond = ['style', 'color', 'sketch', 'depth', 'canny']

# config
# Only the checkpoint paths are user-configurable from the command line.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--sd_ckpt',
    type=str,
    default='models/v1-5-pruned-emaonly.ckpt',
    help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
)
parser.add_argument(
    '--vae_ckpt',
    type=str,
    default=None,
    help='vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
)
global_opt = parser.parse_args()

# Everything below is fixed for this demo and stored on the same namespace
# that holds the parsed command-line options.
vars(global_opt).update(
    config='configs/stable-diffusion/sd-v1-inference.yaml',
    # One "<cond>_adapter_ckpt" attribute per supported condition.
    **{f'{name}_adapter_ckpt': f'models/coadapter-{name}-sd15v1.pth' for name in supported_cond},
)
global_opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_opt.max_resolution = 512 * 512
global_opt.sampler = 'ddim'
global_opt.cond_weight = 1.0
global_opt.C = 4
global_opt.f = 8
#TODO: expose style_cond_tau to users
global_opt.style_cond_tau = 1.0
# stable-diffusion model
sd_model, sampler = get_sd_models(global_opt)

# adapters and models to processing condition inputs
# Both caches are filled lazily by run(), keyed by condition name.
adapters = {}
cond_models = {}
torch.cuda.empty_cache()

# fuser is indispensable
# NOTE(review): the keyword is spelled `n_layes` here; presumably this matches
# the CoAdapterFuser signature -- confirm before "fixing" the spelling.
coadapter_fuser = CoAdapterFuser(unet_channels=[320, 640, 1280, 1280], width=768, num_head=8, n_layes=3)
# Deserialize onto the CPU first so the checkpoint also loads on hosts without
# CUDA (global_opt.device falls back to CPU there); the module is moved to the
# target device right after.
coadapter_fuser.load_state_dict(torch.load('models/coadapter-fuser-sd15v1.pth', map_location='cpu'))
coadapter_fuser = coadapter_fuser.to(global_opt.device)
def run(*args):
    """Generate images guided by the currently enabled CoAdapter conditions.

    Args:
        *args: flat list of Gradio input values. It consists of four
            consecutive groups of ``len(supported_cond)`` values (input-type
            choices, raw images, ready-made condition images, condition
            weights) followed by eight generation settings: prompt, negative
            prompt, guidance scale, number of samples, seed, steps, short-edge
            resolution and cond_tau.

    Returns:
        A tuple ``(ims, output_conds)``: the generated images and the
        condition maps that were fed to the adapters, each converted with
        ``tensor2img``.

    Raises:
        gr.Error: if a condition is enabled but no image was provided for it.
    """
    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):
        n_conds = len(supported_cond)
        n_settings = 8  # trailing generation settings, see the unpacking below

        # Split the leading values into per-kind groups of `n_conds` entries.
        inps = [args[i:i + n_conds] for i in range(0, len(args) - n_settings, n_conds)]

        # Work on a copy so per-request settings never leak into global_opt.
        opt = copy.deepcopy(global_opt)
        (opt.prompt, opt.neg_prompt, opt.scale, opt.n_samples, opt.seed, opt.steps,
         opt.resize_short_edge, opt.cond_tau) = args[-n_settings:]

        conds = []
        activated_conds = []
        for idx, (b, im1, im2, cond_weight) in enumerate(zip(*inps)):
            cond_name = supported_cond[idx]
            if b == 'Nothing':
                # Disabled condition: park an already-loaded adapter on the CPU.
                if cond_name in adapters:
                    adapters[cond_name]['model'] = adapters[cond_name]['model'].cpu()
                continue

            # Fail early with an actionable message instead of passing None
            # into the condition preprocessing.
            selected_image = im1 if b == 'Image' else im2
            if selected_image is None:
                raise gr.Error(
                    f'No image provided for the "{cond_name}" condition: '
                    'upload one or set its input type to "Nothing".')

            activated_conds.append(cond_name)
            if cond_name in adapters:
                adapters[cond_name]['model'] = adapters[cond_name]['model'].to(opt.device)
            else:
                adapters[cond_name] = get_adapters(opt, getattr(ExtraCondition, cond_name))
            adapters[cond_name]['cond_weight'] = cond_weight

            process_cond_module = getattr(api, f'get_cond_{cond_name}')

            if b == 'Image':
                # A raw image needs the condition model to extract the condition map.
                if cond_name not in cond_models:
                    cond_models[cond_name] = get_cond_model(opt, getattr(ExtraCondition, cond_name))
                conds.append(process_cond_module(opt, im1, 'image', cond_models[cond_name]))
            else:
                # The upload already is the condition map; no extra model is passed.
                conds.append(process_cond_module(opt, im2, cond_name, None))

        # `conds` and `activated_conds` were appended in lockstep above.
        features = dict()
        for cond_name, cond in zip(activated_conds, conds):
            adapter = adapters[cond_name]
            cur_feats = adapter['model'](cond)
            if isinstance(cur_feats, list):
                # Index assignment keeps the original in-place scaling semantics.
                for i in range(len(cur_feats)):
                    cur_feats[i] *= adapter['cond_weight']
            else:
                cur_feats *= adapter['cond_weight']
            features[cond_name] = cur_feats

        adapter_features, append_to_context = coadapter_fuser(features)

        output_conds = [tensor2img(cond, rgb2bgr=False) for cond in conds]

        ims = []
        seed_everything(opt.seed)
        for _ in range(opt.n_samples):
            result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
            ims.append(tensor2img(result, rgb2bgr=False))

        # Clear GPU memory cache so less likely to OOM
        torch.cuda.empty_cache()
        return ims, output_conds
def change_visible(im1, im2, val):
    """Build the visibility updates for one condition's two image inputs.

    ``im1`` is shown only when ``val`` is "Image", ``im2`` only when ``val``
    is neither "Image" nor "Nothing"; with "Nothing" both are hidden.

    Returns:
        A dict mapping ``im1`` and ``im2`` to their ``gr.update`` objects.
    """
    show_first = val == "Image"
    show_second = not show_first and val != "Nothing"
    return {
        im1: gr.update(visible=show_first),
        im2: gr.update(visible=show_second),
    }
DESCRIPTION = '''# CoAdapter | |
[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter) | |
This gradio demo is for a simple experience of CoAdapter: | |
''' | |
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation was lost -- confirm it against the rendered page.
with gr.Blocks(title="CoAdapter", css=".gr-box {border-color: #8136e2}") as demo:
    gr.Markdown(DESCRIPTION)

    # Per-condition widgets, collected in the order of `supported_cond`.
    btns, ims1, ims2, cond_weights = [], [], [], []

    with gr.Row():
        for cond_name in supported_cond:
            with gr.Box():
                with gr.Column():
                    # Selects where this condition comes from, if it is used at all.
                    btn1 = gr.Radio(
                        choices=["Image", cond_name, "Nothing"],
                        label=f"Input type for {cond_name}",
                        interactive=True,
                        value="Nothing",
                    )
                    # Both uploads start hidden; the radio button reveals one of them.
                    im1 = gr.Image(source='upload', label="Image", interactive=True, visible=False, type="numpy")
                    im2 = gr.Image(source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
                    cond_weight = gr.Slider(
                        label="Condition weight", minimum=0, maximum=5, step=0.05, value=1, interactive=True)

                    # partial binds this iteration's components to the callback.
                    btn1.change(
                        fn=partial(change_visible, im1, im2),
                        inputs=[btn1],
                        outputs=[im1, im2],
                        queue=False,
                    )

                    btns.append(btn1)
                    ims1.append(im1)
                    ims2.append(im2)
                    cond_weights.append(cond_weight)

    with gr.Column():
        prompt = gr.Textbox(label="Prompt")
        neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
        scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", value=7.5, minimum=1, maximum=20, step=0.1)
        n_samples = gr.Slider(label="Num samples", value=1, minimum=1, maximum=8, step=1)
        seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1)
        steps = gr.Slider(label="Steps", value=50, minimum=10, maximum=100, step=1)
        resize_short_edge = gr.Slider(label="Image resolution", value=512, minimum=320, maximum=1024, step=1)
        cond_tau = gr.Slider(
            label="timestamp parameter that determines until which step the adapter is applied",
            value=1.0,
            minimum=0.1,
            maximum=1.0,
            step=0.05)

    with gr.Row():
        submit = gr.Button("Generate")
    output = gr.Gallery().style(grid=2, height='auto')
    cond = gr.Gallery().style(grid=2, height='auto')

    # Input order is the contract with run(): four per-condition groups first,
    # then the eight generation settings.
    inps = [
        *btns, *ims1, *ims2, *cond_weights,
        prompt, neg_prompt, scale, n_samples, seed, steps, resize_short_edge, cond_tau,
    ]
    submit.click(fn=run, inputs=inps, outputs=[output, cond])

# demo.launch()
demo.queue().launch(debug=True, server_name='0.0.0.0')