zhiweili commited on
Commit
304cdbb
1 Parent(s): c823534

add control net

Browse files
Files changed (1) hide show
  1. app_haircolor_inpaint_15.py +44 -5
app_haircolor_inpaint_15.py CHANGED
@@ -10,10 +10,20 @@ from segment_utils import(
10
  restore_result,
11
  )
12
  from diffusers import (
13
- StableDiffusionInpaintPipeline,
 
 
 
14
  EulerAncestralDiscreteScheduler,
15
  )
16
 
 
 
 
 
 
 
 
17
  # BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
18
  BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-inpainting"
19
  # BASE_MODEL = "SG161222/Realistic_Vision_V2.0"
@@ -25,12 +35,34 @@ DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res,
25
 
26
  DEFAULT_CATEGORY = "hair"
27
 
28
- basepipeline = StableDiffusionInpaintPipeline.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  BASE_MODEL,
30
  torch_dtype=torch.float16,
31
  # use_safetensors=True,
 
32
  )
33
-
34
  basepipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(basepipeline.scheduler.config)
35
 
36
  basepipeline = basepipeline.to(DEVICE)
@@ -52,6 +84,11 @@ def image_to_image(
52
  run_task_time = 0
53
  time_cost_str = ''
54
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
 
 
 
 
55
 
56
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
57
  generated_image = basepipeline(
@@ -60,10 +97,12 @@ def image_to_image(
60
  negative_prompt=DEFAULT_NEGATIVE_PROMPT,
61
  image=input_image,
62
  mask_image=mask_image,
 
63
  height=generate_size,
64
  width=generate_size,
65
  guidance_scale=guidance_scale,
66
  num_inference_steps=num_steps,
 
67
  ).images[0]
68
 
69
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -103,8 +142,8 @@ def create_demo() -> gr.Blocks:
103
  guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
104
  with gr.Column():
105
  with gr.Accordion("Advanced Options", open=False):
106
- cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1, step=0.1, label="Cond Scale1")
107
- cond_scale2 = gr.Slider(minimum=0, maximum=3, value=0.6, step=0.1, label="Cond Scale2")
108
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
109
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
110
  seed = gr.Number(label="Seed", value=8)
 
10
  restore_result,
11
  )
12
  from diffusers import (
13
+ StableDiffusionControlNetInpaintPipeline,
14
+ ControlNetModel,
15
+ DDIMScheduler,
16
+ DPMSolverMultistepScheduler,
17
  EulerAncestralDiscreteScheduler,
18
  )
19
 
20
+ from controlnet_aux import (
21
+ CannyDetector,
22
+ LineartDetector,
23
+ PidiNetDetector,
24
+ HEDdetector,
25
+ )
26
+
27
  # BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
28
  BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-inpainting"
29
  # BASE_MODEL = "SG161222/Realistic_Vision_V2.0"
 
35
 
36
  DEFAULT_CATEGORY = "hair"
37
 
38
+ canny_detector = CannyDetector()
39
+ lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
40
+ lineart_detector = lineart_detector.to(DEVICE)
41
+
42
+ pidiNet_detector = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
43
+ pidiNet_detector = pidiNet_detector.to(DEVICE)
44
+
45
+ hed_detector = HEDdetector.from_pretrained('lllyasviel/Annotators')
46
+ hed_detector = hed_detector.to(DEVICE)
47
+
48
+ controlnet = [
49
+ ControlNetModel.from_pretrained(
50
+ "lllyasviel/control_v11p_sd15_lineart",
51
+ torch_dtype=torch.float16,
52
+ ),
53
+ ControlNetModel.from_pretrained(
54
+ "lllyasviel/control_v11p_sd15_softedge",
55
+ torch_dtype=torch.float16,
56
+ ),
57
+ ]
58
+
59
+ basepipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
60
  BASE_MODEL,
61
  torch_dtype=torch.float16,
62
  # use_safetensors=True,
63
+ controlnet=controlnet,
64
  )
65
+ # basepipeline.scheduler = DDIMScheduler.from_config(basepipeline.scheduler.config)
66
  basepipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(basepipeline.scheduler.config)
67
 
68
  basepipeline = basepipeline.to(DEVICE)
 
84
  run_task_time = 0
85
  time_cost_str = ''
86
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
87
+ # canny_image = canny_detector(input_image, int(generate_size*1), generate_size)
88
+ lineart_image = lineart_detector(input_image, 384, generate_size)
89
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
90
+ pidiNet_image = pidiNet_detector(input_image, 512, generate_size)
91
+ control_image = [lineart_image, pidiNet_image]
92
 
93
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
94
  generated_image = basepipeline(
 
97
  negative_prompt=DEFAULT_NEGATIVE_PROMPT,
98
  image=input_image,
99
  mask_image=mask_image,
100
+ control_image=control_image,
101
  height=generate_size,
102
  width=generate_size,
103
  guidance_scale=guidance_scale,
104
  num_inference_steps=num_steps,
105
+ controlnet_conditioning_scale=[cond_scale1, cond_scale2],
106
  ).images[0]
107
 
108
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
142
  guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
143
  with gr.Column():
144
  with gr.Accordion("Advanced Options", open=False):
145
+ cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond Scale1")
146
+ cond_scale2 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond Scale2")
147
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
148
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
149
  seed = gr.Number(label="Seed", value=8)