# Cognomen's picture
# update wd checkpoint to flax
# b485002
# raw
# history blame
# 3.3 kB
import gradio as gr
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
from diffusers import UniPCMultistepScheduler
import torch
torch.backends.cuda.matmul.allow_tf32 = True
import torchvision
import torchvision.transforms as T
from flax.jax_utils import replicate
from flax.training.common_utils import shard
#from torchvision.transforms import v2 as T2
import cv2
import PIL
from PIL import Image
import numpy as np
import torchvision.transforms.functional as F
# Crop size handed to RandomCrop for the conditioning image.
output_res = (768, 768)

# Preprocessing applied to the conditioning image: a random crop to
# `output_res` (padding symmetrically when the input is smaller), conversion
# to a tensor, then normalization with mean 0.5 and std 0.5.
conditioning_image_transforms = T.Compose(
    [
        # Disabled: T2.ScaleJitter(target_size=output_res, scale_range=(0.5, 3.0))
        T.RandomCrop(output_res, pad_if_needed=True, padding_mode="symmetric"),
        T.ToTensor(),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)
# Load the ControlNet and the base pipeline from local checkpoints in
# bfloat16. Each `from_pretrained` result is unpacked into the object and
# its parameters.
cnet, cnet_params = FlaxControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd",
    from_flax=True,
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2-flax",
    revision="flax",
    controlnet=cnet,
    dtype=jnp.bfloat16,
)

# Swap the pipeline's scheduler for UniPC, reusing the loaded scheduler config.
# NOTE(review): UniPCMultistepScheduler is not one of the Flax* classes —
# confirm it is compatible with the Flax pipeline's scheduler interface.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Currently disabled:
# pipe.enable_model_cpu_offload()
# pipe.enable_xformers_memory_efficient_attention()
def get_random(seed: int) -> "jax.Array":
    """Return a JAX PRNG key built from ``seed``.

    Args:
        seed: Integer seed for the key.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    # The key must be returned; callers split it across devices.
    return jax.random.PRNGKey(seed)


# inference function takes prompt, negative prompt and image
def infer(prompt: str, negative_prompt: str, image: np.ndarray) -> list:
    """Generate images for a prompt, conditioned on an example image.

    The conditioning image is run through ``conditioning_image_transforms``,
    converted back to a PIL image and handed to the pipeline together with
    the tokenized prompt and negative prompt. One sample is generated per
    JAX device, with a PRNG key derived from the fixed seed 0.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        image: Conditioning image as an array accepted by ``Image.fromarray``.

    Returns:
        The generated images, as returned by ``pipe.numpy_to_pil``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("A conditioning image is required.")

    params["controlnet"] = cnet_params

    # `shard` splits the leading batch axis across devices, so the batch must
    # hold one sample per device (this equals 1 on a single-device host).
    num_devices = jax.device_count()
    num_samples = num_devices

    cond_input = conditioning_image_transforms(Image.fromarray(image))
    # NOTE(review): the tensor was normalized with mean/std 0.5 before this
    # conversion back to PIL — confirm the round trip yields the conditioning
    # image the model expects.
    cond_input = T.ToPILImage()(cond_input)

    # Every batched input is sharded because the pipeline runs with jit=True.
    cond_img_in = shard(pipe.prepare_image_inputs([cond_input] * num_samples))
    prompt_in = shard(pipe.prepare_text_inputs([prompt] * num_samples))
    n_prompt_in = shard(pipe.prepare_text_inputs([negative_prompt] * num_samples))

    # One PRNG key per device; the parameters are replicated to match.
    rng = jax.random.split(get_random(0), num_devices)
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_in,
        image=cond_img_in,
        params=p_params,
        prng_seed=rng,
        neg_prompt_ids=n_prompt_in,
        num_inference_steps=20,
        jit=True,
    ).images

    # Collapse the per-device axis back into a flat batch of images.
    flat_output = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat_output))
# Build the Gradio UI around `infer`: two single-line text boxes (prompt and
# negative prompt) plus an image input, with results shown in a gallery.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(
            label="Enter prompt",
            max_lines=1,
            placeholder="1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
        ),
        gr.Textbox(
            label="Enter negative prompt",
            max_lines=1,
            placeholder="low quality",
        ),
        gr.Image(),
    ],
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title="Generate controlled outputs with Categorical Conditioning on Waifu Diffusion 1.5 beta 2.",
    description="This Space uses image examples as style conditioning.",
    examples=[
        [
            "1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
            "low quality",
            "wikipe_cond_1.png",
        ]
    ],
    allow_flagging=False,
)

# Start serving the interface with request queuing enabled.
demo.launch(enable_queue=True)