from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from PIL import Image
import torch
import numpy as np
import cv2
import gradio as gr
from torchvision import transforms

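# Load the Canny ControlNet in half precision and move it to the GPU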
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.2-ControlNet-Canny",
    torch_dtype=torch.float16
).to('cuda')

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.2",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    offload_state_dict=True,
).to('cuda')  # device_map='auto' is unsupported for pipelines and would conflict with the explicit .to('cuda')
# pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
pipe.enable_xformers_memory_efficient_attention()
pipe.force_zeros_for_empty_prompt = False

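# Hysteresis thresholds for OpenCV's Canny edge detector (cv2.Canny)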
low_threshold = 100
high_threshold = 200

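# Center-crop the input to a square and resize to 1024x1024, SDXL's native resolution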
def resize_image(image):
    image = image.convert('RGB')
    current_size = image.size
    if current_size[0] > current_size[1]:
        center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1]))
    else:
        center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0]))
    resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024))
    return resized_image

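# Run Canny edge detection and stack the single-channel result into the
# 3-channel RGB image the ControlNet expects as its conditioning input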
def get_canny_filter(image):
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image

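# Gradio click handler: preprocess the upload, extract its edge map, and run
# the ControlNet-guided generation; returns the edge map alongside the result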
def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    generator = torch.manual_seed(seed)

    # center-crop and resize the input to 1024x1024
    input_image = resize_image(input_image)

    # extract the Canny edge map used to condition the ControlNet
    canny_image = get_canny_filter(input_image)

    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=canny_image,
        num_inference_steps=num_steps,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
    ).images

    return [canny_image, images[0]]
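
# Gradio UI: image upload and generation controls on the left, edge map and
# generated image shown side by side in the gallery on the right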
block = gr.Blocks().queue()

with block:
    gr.Markdown("## BRIA 2.2 ControlNet Canny")
    gr.HTML('''
      <p style="margin-bottom: 10px; font-size: 94%">
        This is a demo of ControlNet Canny using the
        <a href="https://huggingface.co/briaai/BRIA-2.2" target="_blank">BRIA 2.2 text-to-image model</a> as its backbone.
        Trained on licensed data, BRIA 2.2 provides full legal liability coverage for copyright and privacy infringement.
      </p>
    ''')
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=None, type="pil")  # sources=None keeps the defaults: upload, paste (Ctrl+V) and webcam
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
            num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
            run_button = gr.Button(value="Run")

        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

block.launch(debug=True)