Spaces:
Runtime error
Runtime error
File size: 4,222 Bytes
f3f94c7 0bd9523 f3f94c7 9a97b4c f3f94c7 96be33e e2ebf36 f3f94c7 9a97b4c f3f94c7 124ff35 f3f94c7 16871fc f3f94c7 16871fc f3f94c7 54c0124 f3f94c7 12000db f3f94c7 419164d f3f94c7 9b1a334 419164d f3f94c7 12000db 54c0124 61f7e82 12000db f3f94c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import spaces
import gradio as gr
import torch
from diffusers import (
AutoencoderKL,
EulerAncestralDiscreteScheduler,
)
from diffusers.utils import load_image
from replace_bg.model.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from replace_bg.model.controlnet import ControlNetModel
from replace_bg.utilities import resize_image, remove_bg_from_image, paste_fg_over_image, get_control_image_tensor
# Every checkpoint below is loaded with the same half-precision dtype.
_weights_dtype = torch.float16

# Background-generation ControlNet, passed to the pipeline below.
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.3-ControlNet-BG-Gen",
    torch_dtype=_weights_dtype,
)

# Standalone VAE checkpoint, handed to the pipeline explicitly via `vae=`.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=_weights_dtype,
)

# Assemble the SDXL ControlNet pipeline on top of the BRIA 2.3 base model,
# then move it to the first CUDA device.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.3",
    controlnet=controlnet,
    torch_dtype=_weights_dtype,
    vae=vae,
)
pipe = pipe.to('cuda:0')

# Replace the pipeline's scheduler with an Euler-ancestral one built from
# explicit settings rather than from the checkpoint's scheduler config.
_scheduler_settings = {
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "num_train_timesteps": 1000,
    "steps_offset": 1,
}
pipe.scheduler = EulerAncestralDiscreteScheduler(**_scheduler_settings)
@spaces.GPU
def generate_(prompt, negative_prompt, control_tensor, num_steps, controlnet_conditioning_scale, seed):
    """Run the ControlNet pipeline once and return the first generated image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative text prompt forwarded to the pipeline.
        control_tensor: Conditioning input passed as the pipeline's ``image``
            argument (produced by ``get_control_image_tensor`` in ``process``).
        num_steps: Number of inference steps; coerced to ``int``.
        controlnet_conditioning_scale: ControlNet strength; coerced to ``float``.
        seed: Seed for the CUDA random generator; coerced to ``int``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Values coming from the UI/API are not guaranteed to be Python ints.
    # torch.Generator.manual_seed rejects floats, so normalise here the same
    # way the conditioning scale is normalised with float() below.
    generator = torch.Generator("cuda").manual_seed(int(seed))

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=control_tensor,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        num_inference_steps=int(num_steps),
        generator=generator,
    )
    return output.images[0]
@spaces.GPU
def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    """Generate a new background for an uploaded image (Gradio click handler).

    Steps, in order: resize the upload with ``resize_image``, obtain a mask
    from ``remove_bg_from_image``, build the ControlNet conditioning input
    with ``get_control_image_tensor``, run ``generate_``, and finally combine
    the generated image with the resized input and mask through
    ``paste_fg_over_image``.

    Args:
        input_image: The uploaded image (the UI component uses ``type="pil"``).
        prompt: Text prompt for the generated background.
        negative_prompt: Negative text prompt.
        num_steps: Number of inference steps.
        controlnet_conditioning_scale: ControlNet conditioning strength.
        seed: Random seed for generation.

    Returns:
        The image returned by ``paste_fg_over_image``.

    Raises:
        gr.Error: If no image was provided.
    """
    # The image component yields None when nothing has been uploaded; surface
    # a readable message in the UI instead of failing inside resize_image.
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    image = resize_image(input_image)
    mask = remove_bg_from_image(image)
    control_tensor = get_control_image_tensor(pipe.vae, image, mask)
    gen_image = generate_(
        prompt,
        negative_prompt,
        control_tensor,
        num_steps,
        controlnet_conditioning_scale,
        seed,
    )
    return paste_fg_over_image(gen_image, image, mask)
# Gradio UI: inputs in the left column, generated result in the right column.
block = gr.Blocks().queue()

with block:
    gr.Markdown("## BRIA Background Generation")
    # The link must be an absolute URL; a bare repo id would be resolved
    # relative to the page hosting the demo.
    gr.HTML('''
      <p style="margin-bottom: 10px; font-size: 94%">
        This is a demo for ControlNet background generation that using
        <a href="https://huggingface.co/briaai/BRIA-2.3-ControlNet-BG-Gen" target="_blank">BRIA 2.3 text-to-image model</a> as backbone.
        Trained on licensed data, BRIA 2.3 provide full legal liability coverage for copyright and privacy infringement.
      </p>
    ''')
    with gr.Row():
        with gr.Column():
            # Upload only; per the original note, sources=None would also
            # allow clipboard paste (ctrl+v) and webcam.
            input_image = gr.Image(sources='upload', type="pil", label="Upload", elem_id="image_upload", height=600)
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
            num_steps = gr.Slider(label="Number of steps", minimum=10, maximum=100, value=30, step=1)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
            run_button = gr.Button(value="Generate")
        with gr.Column():
            result_gallery = gr.Image(label='Output', type="pil", show_label=True, elem_id="output-img")

    # Order must match the positional parameters of process().
    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

    # Clicking an example only fills the upload component; with
    # cache_examples=False nothing is precomputed.
    gr.Examples(
        examples=[
            ["./example1.png"],
            ["./example2.png"],
            ["./example3.png"],
            ["./example4.png"],
        ],
        fn=process,
        inputs=[input_image],
        cache_examples=False,
    )

block.launch(debug=True)