File size: 5,421 Bytes
566c306
0842498
566c306
0842498
566c306
 
0842498
566c306
 
 
 
0842498
566c306
0842498
566c306
0842498
566c306
 
0842498
566c306
 
 
 
 
 
 
0842498
566c306
 
0842498
 
 
 
923b86d
0842498
 
f6d84d3
566c306
 
 
 
 
 
 
 
 
 
 
 
a6ab172
566c306
 
 
 
 
 
 
 
a6ab172
566c306
 
 
 
 
 
 
 
810c96b
566c306
37e1426
566c306
 
 
 
 
 
 
 
 
 
 
 
 
 
95a9f0f
566c306
 
95a9f0f
566c306
 
 
 
 
 
a9d82c4
 
24591c6
a9d82c4
 
 
 
 
 
90147a1
566c306
63a0180
566c306
 
3638fca
566c306
f817fc9
83f458c
95a9f0f
566c306
83f458c
a9d82c4
95a9f0f
566c306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95a9f0f
 
566c306
 
 
 
95a9f0f
566c306
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python

# from __future__ import annotations

# import gradio as gr
# import torch

# from app_canny import create_demo as create_demo_canny
# # from app_depth import create_demo as create_demo_depth
# # from app_recoloring import create_demo as create_demo_recoloring
# from model import Model

# DESCRIPTION = "# BRIA 2.2 ControlNets"

# model = Model(base_model_id='briaai/BRIA-2.2', task_name="Canny")

# with gr.Blocks(css="style.css") as demo:
#     gr.Markdown(DESCRIPTION)

#     with gr.Tabs():
#         with gr.TabItem("Canny"):
#             create_demo_canny(model.process_canny)
#         # with gr.TabItem("Depth (Future)"):
#         #     create_demo_canny(model.process_mlsd)
#         # with gr.TabItem("Recoloring (Future)"):
#         #     create_demo_canny(model.process_scribble)

# if __name__ == "__main__":
#     demo.queue(max_size=20).launch()




################################################################


import spaces
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
from diffusers.utils import load_image
from PIL import Image
import torch
import numpy as np
import cv2
import gradio as gr
from torchvision import transforms 

# Canny-conditioned ControlNet weights for the BRIA 2.2 base model, loaded in fp16.
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.2-ControlNet-Canny",
    torch_dtype=torch.float16
) #.to('cuda')

# SDXL ControlNet pipeline on top of the BRIA 2.2 base model, using the ControlNet above.
# NOTE(review): device_map='auto' is combined with an explicit pipe.to('cuda') further
# down; depending on the pinned diffusers version, pipeline-level device_map may only
# accept specific strategies (e.g. "balanced") — confirm this loads as intended.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.2",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    device_map='auto',
    low_cpu_mem_usage=True,
    offload_state_dict=True,
) #.to('cuda')
# Replace the pipeline's default scheduler with Euler Ancestral, configured with an
# explicit scaled-linear beta schedule (values hard-coded here, not read from the model).
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1
)
# pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
# pipe.enable_xformers_memory_efficient_attention()
# Presumably makes an empty negative prompt get encoded as text instead of being
# replaced by zero embeddings — TODO confirm against the diffusers SDXL docs.
pipe.force_zeros_for_empty_prompt = False
# Move the whole pipeline onto the GPU and make sure it runs in fp16.
pipe.to('cuda').to(torch.float16)
# Lower / upper thresholds passed to cv2.Canny by get_canny_filter().
low_threshold = 100
high_threshold = 200

def resize_image(image, size=1024):
    """Center-crop *image* to a square and resize it to ``size`` x ``size``.

    The crop keeps the centered square whose side equals the shorter image
    dimension, so the subsequent resize does not distort the aspect ratio.

    Args:
        image: A PIL image; it is converted to RGB before cropping.
        size: Side length in pixels of the returned square image. Defaults to
            1024, the resolution this demo feeds to the pipeline.

    Returns:
        The cropped and resized RGB PIL image.
    """
    image = image.convert('RGB')
    # The shorter side bounds the largest centered square that fits in the image.
    side = min(image.size)
    center_cropped_image = transforms.functional.center_crop(image, (side, side))
    return transforms.functional.resize(center_cropped_image, (size, size))

def get_canny_filter(image):
    """Return a 3-channel PIL image holding the Canny edge map of *image*.

    *image* may be a NumPy array or anything ``np.array`` can convert (such as
    a PIL image). Edges are detected with the module-level ``low_threshold`` and
    ``high_threshold``; the single-channel edge map is then copied into three
    identical channels so it can be used as an RGB control image.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Stack the lone edge channel three times along a new last axis -> H x W x 3.
    rgb_edges = np.repeat(edges[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(rgb_edges)

@spaces.GPU
def generate_(prompt, negative_prompt, canny_image, num_steps, controlnet_conditioning_scale, seed):
    """Run the ControlNet pipeline once and return the generated images.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        canny_image: Control image (Canny edge map) passed as ``image=``.
        num_steps: Number of inference steps; coerced to ``int``.
        controlnet_conditioning_scale: ControlNet strength; coerced to ``float``.
        seed: Seed for the CUDA ``torch.Generator``; coerced to ``int``.

    Returns:
        The ``images`` list from the pipeline output.
    """
    # Values coming from UI sliders may arrive as floats; manual_seed and the
    # scheduler's step count require integers, so coerce them explicitly
    # (the conditioning scale was already coerced to float the same way).
    generator = torch.Generator("cuda").manual_seed(int(seed))
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=canny_image,
        num_inference_steps=int(num_steps),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
    ).images
    return images

@spaces.GPU
def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    """Gradio callback: build a Canny control image and generate from it.

    Args:
        input_image: Uploaded PIL image, or ``None`` if nothing was provided.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        num_steps: Number of inference steps.
        controlnet_conditioning_scale: ControlNet strength.
        seed: Generator seed.

    Returns:
        A two-item list ``[canny_image, generated_image]`` for the output gallery.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Clicking "Run" without an upload passes None; report it in the UI
        # instead of crashing with an AttributeError inside resize_image.
        raise gr.Error("Please upload an input image first.")

    # Center-crop and resize the input to 1024x1024.
    input_image = resize_image(input_image)

    canny_image = get_canny_filter(input_image)

    images = generate_(prompt, negative_prompt, canny_image, num_steps, controlnet_conditioning_scale, seed)

    return [canny_image, images[0]]


    
block = gr.Blocks().queue()

# UI layout: inputs on the left, a gallery showing the edge map and the result on the right.
with block:
    gr.Markdown("## BRIA 2.2 ControlNet Canny")
    gr.HTML('''
      <p style="margin-bottom: 10px; font-size: 94%">
        This is a demo for ControlNet Canny that uses the
        <a href="https://huggingface.co/briaai/BRIA-2.2" target="_blank">BRIA 2.2 text-to-image model</a> as its backbone.
        Trained on licensed data, BRIA 2.2 provides full legal liability coverage for copyright and privacy infringement.
      </p>
    ''')
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
            num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
            run_button = gr.Button(value="Run")

        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
    # Order must match the positional parameters of process().
    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

block.launch(debug = True)