ysmao committed
Commit 4342954
1 Parent(s): e3e1936

add layout controlnet

Files changed (5)
  1. annotator/dsine_hub.py +37 -0
  2. annotator/midas.py +34 -0
  3. annotator/upernet.py +190 -0
  4. annotator/util.py +38 -0
  5. app.py +147 -4
annotator/dsine_hub.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+
+
+ class NormalDetector:
+     def __init__(self):
+         self.model_path = "hugoycj/DSINE-hub"
+         self.dsine = torch.hub.load(self.model_path, "DSINE", trust_repo=True)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     @torch.no_grad()
+     def __call__(self, image):
+         self.dsine.model.to(self.device)
+         self.dsine.model.pixel_coords = self.dsine.model.pixel_coords.to(self.device)
+         H, W, C = image.shape
+
+         normal = self.dsine.infer_pil(image)[0]  # channel-first normal map; transposed to (H, W, 3) below
+         normal = (normal + 1.0) / 2.0  # map values from [-1, 1] to [0, 1]
+         normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
+         normal_img = Image.fromarray(normal).resize((W, H))
+
+         self.dsine.model.to("cpu")
+         self.dsine.model.pixel_coords = self.dsine.model.pixel_coords.to("cpu")
+         return normal_img
+
+
+ if __name__ == "__main__":
+     from diffusers.utils import load_image
+
+     image = load_image(
+         "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
+     )
+     image = np.asarray(image)
+     normal_detector = NormalDetector()
+     normal_image = normal_detector(image)
+     normal_image.save("normal_image.jpg")
annotator/midas.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import DPTFeatureExtractor
+ from transformers import DPTForDepthEstimation
+
+
+ class DepthDetector:
+     def __init__(self, model_path=None):
+         if model_path is not None:
+             self.model_path = model_path
+         else:
+             self.model_path = "Intel/dpt-hybrid-midas"
+         self.model = DPTForDepthEstimation.from_pretrained(self.model_path)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.feature_extractor = DPTFeatureExtractor.from_pretrained(self.model_path)
+
+     @torch.no_grad()
+     def __call__(self, image):
+         self.model.to(self.device)
+         H, W, C = image.shape
+         inputs = self.feature_extractor(images=image, return_tensors="pt")
+         inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
+         outputs = self.model(**inputs)
+         predicted_depth = outputs.predicted_depth
+         outputs = predicted_depth.squeeze().cpu().numpy()
+         if len(outputs.shape) == 2:
+             output = outputs[np.newaxis, :, :]
+         else:
+             output = outputs
+         formatted = (output * 255 / np.max(output)).astype("uint8")  # normalize depth to [0, 255]
+         depth_image = Image.fromarray(formatted[0, ...]).resize((W, H))  # back to the input resolution
+         self.model.to("cpu")
+         return depth_image
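
For quick verification outside the Space, the depth annotator can be exercised on its own, mirroring the __main__ block in annotator/dsine_hub.py. A minimal sketch (it reuses the sample image URL from that file; the inline comments describe what the code above implies, not documented behavior):

    import numpy as np
    from diffusers.utils import load_image
    from annotator.midas import DepthDetector

    image = np.asarray(load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    ))
    depth_detector = DepthDetector()     # downloads Intel/dpt-hybrid-midas on first use
    depth_image = depth_detector(image)  # PIL image, normalized depth at the input resolution
    depth_image.save("depth_image.jpg")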
annotator/upernet.py ADDED
@@ -0,0 +1,190 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import AutoImageProcessor
+ from transformers import UperNetForSemanticSegmentation
+
+
+ class SegmDetector:
+     def __init__(self, model_path=None):
+         if model_path is not None:
+             self.model_path = model_path
+         else:
+             self.model_path = "openmmlab/upernet-convnext-small"
+         self.model = UperNetForSemanticSegmentation.from_pretrained(self.model_path)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.feature_extractor = AutoImageProcessor.from_pretrained(self.model_path)
+         self.palette = [  # ADE20K color palette (150 classes)
+             [120, 120, 120],
+             [180, 120, 120],
+             [6, 230, 230],
+             [80, 50, 50],
+             [4, 200, 3],
+             [120, 120, 80],
+             [140, 140, 140],
+             [204, 5, 255],
+             [230, 230, 230],
+             [4, 250, 7],
+             [224, 5, 255],
+             [235, 255, 7],
+             [150, 5, 61],
+             [120, 120, 70],
+             [8, 255, 51],
+             [255, 6, 82],
+             [143, 255, 140],
+             [204, 255, 4],
+             [255, 51, 7],
+             [204, 70, 3],
+             [0, 102, 200],
+             [61, 230, 250],
+             [255, 6, 51],
+             [11, 102, 255],
+             [255, 7, 71],
+             [255, 9, 224],
+             [9, 7, 230],
+             [220, 220, 220],
+             [255, 9, 92],
+             [112, 9, 255],
+             [8, 255, 214],
+             [7, 255, 224],
+             [255, 184, 6],
+             [10, 255, 71],
+             [255, 41, 10],
+             [7, 255, 255],
+             [224, 255, 8],
+             [102, 8, 255],
+             [255, 61, 6],
+             [255, 194, 7],
+             [255, 122, 8],
+             [0, 255, 20],
+             [255, 8, 41],
+             [255, 5, 153],
+             [6, 51, 255],
+             [235, 12, 255],
+             [160, 150, 20],
+             [0, 163, 255],
+             [140, 140, 140],
+             [250, 10, 15],
+             [20, 255, 0],
+             [31, 255, 0],
+             [255, 31, 0],
+             [255, 224, 0],
+             [153, 255, 0],
+             [0, 0, 255],
+             [255, 71, 0],
+             [0, 235, 255],
+             [0, 173, 255],
+             [31, 0, 255],
+             [11, 200, 200],
+             [255, 82, 0],
+             [0, 255, 245],
+             [0, 61, 255],
+             [0, 255, 112],
+             [0, 255, 133],
+             [255, 0, 0],
+             [255, 163, 0],
+             [255, 102, 0],
+             [194, 255, 0],
+             [0, 143, 255],
+             [51, 255, 0],
+             [0, 82, 255],
+             [0, 255, 41],
+             [0, 255, 173],
+             [10, 0, 255],
+             [173, 255, 0],
+             [0, 255, 153],
+             [255, 92, 0],
+             [255, 0, 255],
+             [255, 0, 245],
+             [255, 0, 102],
+             [255, 173, 0],
+             [255, 0, 20],
+             [255, 184, 184],
+             [0, 31, 255],
+             [0, 255, 61],
+             [0, 71, 255],
+             [255, 0, 204],
+             [0, 255, 194],
+             [0, 255, 82],
+             [0, 10, 255],
+             [0, 112, 255],
+             [51, 0, 255],
+             [0, 194, 255],
+             [0, 122, 255],
+             [0, 255, 163],
+             [255, 153, 0],
+             [0, 255, 10],
+             [255, 112, 0],
+             [143, 255, 0],
+             [82, 0, 255],
+             [163, 255, 0],
+             [255, 235, 0],
+             [8, 184, 170],
+             [133, 0, 255],
+             [0, 255, 92],
+             [184, 0, 255],
+             [255, 0, 31],
+             [0, 184, 255],
+             [0, 214, 255],
+             [255, 0, 112],
+             [92, 255, 0],
+             [0, 224, 255],
+             [112, 224, 255],
+             [70, 184, 160],
+             [163, 0, 255],
+             [153, 0, 255],
+             [71, 255, 0],
+             [255, 0, 163],
+             [255, 204, 0],
+             [255, 0, 143],
+             [0, 255, 235],
+             [133, 255, 0],
+             [255, 0, 235],
+             [245, 0, 255],
+             [255, 0, 122],
+             [255, 245, 0],
+             [10, 190, 212],
+             [214, 255, 0],
+             [0, 204, 255],
+             [20, 0, 255],
+             [255, 255, 0],
+             [0, 153, 255],
+             [0, 41, 255],
+             [0, 255, 204],
+             [41, 0, 255],
+             [41, 255, 0],
+             [173, 0, 255],
+             [0, 245, 255],
+             [71, 0, 255],
+             [122, 0, 255],
+             [0, 255, 184],
+             [0, 92, 255],
+             [184, 255, 0],
+             [0, 133, 255],
+             [255, 214, 0],
+             [25, 194, 194],
+             [102, 255, 0],
+             [92, 0, 255],
+         ]
+
+     @torch.no_grad()
+     def __call__(self, image):
+         self.model.to(self.device)
+         H, W, C = image.shape
+
+         pixel_values = self.feature_extractor(
+             images=image, return_tensors="pt"
+         ).pixel_values
+         pixel_values = pixel_values.to(self.device)
+         outputs = self.model(pixel_values)
+         segm_image = self.feature_extractor.post_process_semantic_segmentation(outputs)
+         segm_image = segm_image[0].cpu()
+         color_seg = np.zeros(
+             (segm_image.shape[0], segm_image.shape[1], 3), dtype=np.uint8
+         )
+         for label, color in enumerate(self.palette):
+             color_seg[segm_image == label, :] = color
+         color_seg = color_seg.astype(np.uint8)
+         segm_image = Image.fromarray(color_seg).resize((W, H))
+         self.model.to("cpu")
+         return segm_image
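
A side note on the colorization loop in __call__: iterating over the 150 labels works, but the same result can be obtained with a single palette lookup. A minimal sketch, assuming segm is the (H, W) integer label map returned by post_process_semantic_segmentation and palette is the list defined in SegmDetector.__init__ (the helper name colorize is hypothetical):

    import numpy as np

    def colorize(segm: np.ndarray, palette: list) -> np.ndarray:
        lut = np.array(palette, dtype=np.uint8)  # (150, 3) lookup table
        return lut[segm]                         # fancy indexing -> (H, W, 3) color image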
annotator/util.py ADDED
@@ -0,0 +1,38 @@
+ import numpy as np
+ import cv2
+
+
+ def HWC3(x):
+     assert x.dtype == np.uint8
+     if x.ndim == 2:
+         x = x[:, :, None]
+     assert x.ndim == 3
+     H, W, C = x.shape
+     assert C == 1 or C == 3 or C == 4
+     if C == 3:
+         return x
+     if C == 1:
+         return np.concatenate([x, x, x], axis=2)
+     if C == 4:
+         color = x[:, :, 0:3].astype(np.float32)
+         alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+         y = color * alpha + 255.0 * (1.0 - alpha)  # composite onto a white background
+         y = y.clip(0, 255).astype(np.uint8)
+         return y
+
+
+ def resize_image(input_image, resolution):
+     H, W, C = input_image.shape
+     H = float(H)
+     W = float(W)
+     k = float(resolution) / max(H, W)
+     H *= k
+     W *= k
+     H = int(np.round(H / 64.0)) * 64  # snap to a multiple of 64
+     W = int(np.round(W / 64.0)) * 64
+     img = cv2.resize(
+         input_image,
+         (W, H),
+         interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
+     )
+     return img
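
As a worked example of the rounding in resize_image: an 800x1200 input at resolution=768 is scaled by k = 768 / 1200 = 0.64 to 512x768, and both sides are already multiples of 64, so they are kept as-is. A small sketch with a dummy image (shapes only, no real content):

    import numpy as np
    from annotator.util import HWC3, resize_image

    img = np.zeros((800, 1200), dtype=np.uint8)  # grayscale; HWC3 expands it to 3 channels
    out = resize_image(HWC3(img), 768)
    print(out.shape)  # (512, 768, 3), resized with INTER_AREA since k < 1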
app.py CHANGED
@@ -1,7 +1,150 @@
  import gradio as gr
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ import spaces
+ import numpy as np
+ from diffusers import (
+     ControlNetModel,
+     StableDiffusionControlNetPipeline,
+     UniPCMultistepScheduler,
+ )
  import gradio as gr
 
+ from annotator.util import resize_image, HWC3
+ from annotator.midas import DepthDetector
+ from annotator.dsine_hub import NormalDetector
+ from annotator.upernet import SegmDetector
 
+ controlnet_checkpoint = "kujiale-ai/controlnet"
+ # Initialize pipeline
+ controlnet = ControlNetModel.from_pretrained(
+     controlnet_checkpoint,
+     subfolder="control_v1_sd15_layout_fp16",
+     torch_dtype=torch.float16,
+ )
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ).to("cuda")
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ apply_depth = DepthDetector()
+ apply_normal = NormalDetector()
+ apply_segm = SegmDetector()
+
+
+ @spaces.GPU(duration=10)
+ def generate(
+     input_image,
+     prompt,
+     a_prompt,
+     n_prompt,
+     num_samples,
+     image_resolution,
+     steps,
+     strength,
+     guidance_scale,
+     seed,
+ ):
+     color_image = resize_image(HWC3(input_image), image_resolution)
+     # set seed
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     with torch.no_grad():
+         depth_image = apply_depth(color_image)
+         normal_image = apply_normal(color_image)
+         segm_image = apply_segm(color_image)
+
+     # Prepare Layout Control Image
+     depth_image = np.array(depth_image, dtype=np.float32) / 255.0
+     depth_image = torch.from_numpy(depth_image[:, :, None])[None].permute(
+         0, 3, 1, 2
+     )
+     normal_image = np.array(normal_image, dtype=np.float32)
+     normal_image = normal_image / 127.5 - 1.0
+     normal_image = torch.from_numpy(normal_image)[None].permute(0, 3, 1, 2)
+     segm_image = np.array(segm_image, dtype=np.float32) / 255.0
+     segm_image = torch.from_numpy(segm_image)[None].permute(0, 3, 1, 2)
+     control_image = torch.cat([depth_image, normal_image, segm_image], dim=1)  # 7 channels: 1 depth + 3 normal + 3 segm
+
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+     images = pipe(
+         prompt + a_prompt,
+         negative_prompt=n_prompt,
+         num_images_per_prompt=num_samples,
+         num_inference_steps=steps,
+         image=control_image,
+         generator=generator,
+         guidance_scale=guidance_scale,
+         controlnet_conditioning_scale=strength,
+     ).images
+     return images
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## KuJiaLe Layout ControlNet Demo")
+     with gr.Row():
+         input_image = gr.Image(source="upload", type="numpy", label="input_image")
+     with gr.Row():
+         prompt = gr.Textbox(label="Prompt")
+     with gr.Row():
+         run_button = gr.Button(label="Run")
+     with gr.Row():
+         with gr.Column():
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(
+                     label="Images", minimum=1, maximum=2, value=1, step=1
+                 )
+                 image_resolution = gr.Slider(
+                     label="Image Resolution",
+                     minimum=512,
+                     maximum=768,
+                     value=768,
+                     step=64,
+                 )
+                 strength = gr.Slider(
+                     label="Control Strength",
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=1,
+                     step=0.1,
+                 )
+                 steps = gr.Slider(
+                     label="Steps", minimum=1, maximum=50, value=25, step=1
+                 )
+                 guidance_scale = gr.Slider(
+                     label="Guidance Scale",
+                     minimum=0.1,
+                     maximum=20.0,
+                     value=7.5,
+                     step=0.1,
+                 )
+                 seed = gr.Slider(
+                     label="Seed", minimum=-1, maximum=2147483647, value=1, step=1
+                 )
+                 a_prompt = gr.Textbox(
+                     label="Added Prompt", value="best quality, extremely detailed"
+                 )
+                 n_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     value="longbody, lowres, bad anatomy, human, extra digit, fewer digits, cropped, worst quality, low quality",
+                 )
+
+     with gr.Row():
+         image_gallery = gr.Gallery(
+             label="Output", show_label=False, elem_id="gallery"
+         ).style(grid=1, height="auto")
+
+     ips = [
+         input_image,
+         prompt,
+         a_prompt,
+         n_prompt,
+         num_samples,
+         image_resolution,
+         steps,
+         strength,
+         guidance_scale,
+         seed,
+     ]
+     run_button.click(fn=generate, inputs=ips, outputs=[image_gallery])
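
For clarity on what the pipeline receives as image: the three annotator outputs are concatenated along the channel axis into a single 7-channel tensor (1 depth + 3 normal + 3 segmentation channels), which the layout ControlNet checkpoint is presumably trained to consume. A minimal shape check with dummy data (values and sizes are placeholders, not tied to the checkpoint):

    import torch

    H, W = 512, 768
    depth = torch.rand(1, 1, H, W)           # depth scaled to [0, 1]
    normal = torch.rand(1, 3, H, W) * 2 - 1  # normals mapped to [-1, 1]
    segm = torch.rand(1, 3, H, W)            # palette colors scaled to [0, 1]
    control_image = torch.cat([depth, normal, segm], dim=1)
    print(control_image.shape)               # torch.Size([1, 7, 512, 768])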