zhiweili committed
Commit 991068d
1 Parent(s): 991954d

change to img2img

Files changed (2):
  1. app.py +1 -1
  2. app_haircolor_img2img.py +40 -37
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from app_haircolor_inpaint_15 import create_demo as create_demo_haircolor
+from app_haircolor_img2img import create_demo as create_demo_haircolor
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
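
The only change here is which module supplies create_demo. Since create_demo() returns a gr.Blocks (see its signature later in this diff), app.py presumably instantiates it inside the tab layout. A minimal sketch of that assumed wiring; only the import line is confirmed by the diff, and the tab label is a placeholder:

# Sketch of the assumed app.py wiring after this commit; the tab label
# and launch call are assumptions, not part of the diff.
import gradio as gr

from app_haircolor_img2img import create_demo as create_demo_haircolor

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab("Hair Color"):  # hypothetical tab label
            create_demo_haircolor()

demo.launch()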
app_haircolor_img2img.py CHANGED
@@ -10,22 +10,19 @@ from segment_utils import(
     restore_result,
 )
 from diffusers import (
-    StableDiffusionControlNetImg2ImgPipeline,
-    ControlNetModel,
-    DDIMScheduler,
-    DPMSolverMultistepScheduler,
+    DiffusionPipeline,
+    T2IAdapter,
+    MultiAdapter,
+    AutoencoderKL,
     EulerAncestralDiscreteScheduler,
-    UniPCMultistepScheduler,
 )
 
 from controlnet_aux import (
     CannyDetector,
     LineartDetector,
-    PidiNetDetector,
-    HEDdetector,
 )
 
-BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -37,32 +34,34 @@ canny_detector = CannyDetector()
 lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
 lineart_detector = lineart_detector.to(DEVICE)
 
-pidiNet_detector = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
-pidiNet_detector = pidiNet_detector.to(DEVICE)
-
-hed_detector = HEDdetector.from_pretrained('lllyasviel/Annotators')
-hed_detector = hed_detector.to(DEVICE)
-
-controlnet = [
-    ControlNetModel.from_pretrained(
-        "lllyasviel/control_v11p_sd15_lineart",
-        torch_dtype=torch.float16,
-    ),
-    ControlNetModel.from_pretrained(
-        "lllyasviel/control_v11p_sd15_softedge",
-        torch_dtype=torch.float16,
-    ),
-]
+adapters = MultiAdapter(
+    [
+        T2IAdapter.from_pretrained(
+            "TencentARC/t2i-adapter-lineart-sdxl-1.0",
+            torch_dtype=torch.float16,
+            variant="fp16",
+        ),
+        T2IAdapter.from_pretrained(
+            "TencentARC/t2i-adapter-canny-sdxl-1.0",
+            torch_dtype=torch.float16,
+            variant="fp16",
+        ),
+    ]
+)
+adapters = adapters.to(torch.float16)
 
-basepipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+basepipeline = DiffusionPipeline.from_pretrained(
     BASE_MODEL,
-    torch_dtype=torch.float16,
+    torch_dtype=torch.float16,
+    variant="fp16",
     use_safetensors=True,
-    controlnet=controlnet,
+    vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
+    scheduler=EulerAncestralDiscreteScheduler.from_pretrained(BASE_MODEL, subfolder="scheduler"),
+    adapter=adapters,
+    custom_pipeline="./pipelines/pipeline_sdxl_adapter_img2img.py",
 )
 
-basepipeline.scheduler = UniPCMultistepScheduler.from_config(basepipeline.scheduler.config)
-
 basepipeline = basepipeline.to(DEVICE)
 
 basepipeline.enable_model_cpu_offload()
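
For orientation, the stock SDXL T2I-Adapter flow that the custom img2img pipeline above presumably extends looks like the sketch below. The custom pipeline file (./pipelines/pipeline_sdxl_adapter_img2img.py) is not part of this diff, so its exact signature is an assumption; the call site later in this file suggests it adds img2img's init image and strength arguments and keeps the controlnet_conditioning_scale keyword. The prompt and image URL here are hypothetical placeholders.

# Minimal sketch of the stock text-to-image SDXL adapter flow in diffusers;
# the custom img2img pipeline loaded above is presumably this plus an init
# image and a strength parameter.
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image
from controlnet_aux import LineartDetector

detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
source = load_image("https://example.com/face.png")  # hypothetical URL
lineart_map = detector(source, detect_resolution=384, image_resolution=1024)

adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-lineart-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# In the stock pipeline, `image` is the adapter's conditioning map (not an
# init image), and the scale keyword is adapter_conditioning_scale.
result = pipe(
    prompt="a portrait with vivid red hair",  # hypothetical prompt
    image=lineart_map,
    adapter_conditioning_scale=0.8,
).images[0]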
@@ -78,15 +77,17 @@ def image_to_image(
     generate_size: int,
     cond_scale1: float = 1.2,
     cond_scale2: float = 1.2,
+    lineart_detect: float = 0.375,
+    canny_detect: float = 0.375,
 ):
     run_task_time = 0
     time_cost_str = ''
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-    lineart_image = lineart_detector(input_image, 768, generate_size)
-    pidinet_image = pidiNet_detector(input_image, 768, generate_size)
-
-    cond_image = [lineart_image, pidinet_image]
+    lineart_image = lineart_detector(input_image, int(generate_size * lineart_detect), generate_size)
+    canny_image = canny_detector(input_image, int(generate_size * canny_detect), generate_size)
 
+    cond_image = [lineart_image, canny_image]
+    cond_scale = [cond_scale1, cond_scale2]
 
     generator = torch.Generator(device=DEVICE).manual_seed(seed)
     generated_image = basepipeline(
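
A note on the detector calls just added: in controlnet_aux, LineartDetector's __call__ takes detect_resolution and image_resolution as its second and third parameters, so the new fraction arguments lower the detection resolution relative to the output size. CannyDetector, however, takes low_threshold and high_threshold in those positional slots in recent controlnet_aux releases, so the sizes may land on the Canny thresholds instead. A keyword-argument sketch that makes the intended mapping explicit, assuming current controlnet_aux signatures:

# Keyword form of the two detector calls from image_to_image above.
# LineartDetector: __call__(input_image, detect_resolution, image_resolution, ...)
# CannyDetector:   __call__(input_image, low_threshold, high_threshold,
#                           detect_resolution, image_resolution, ...)
# Keywords avoid the positional mismatch on the Canny side.
lineart_image = lineart_detector(
    input_image,
    detect_resolution=int(generate_size * lineart_detect),  # coarser detection pass
    image_resolution=generate_size,                         # scale map to output size
)
canny_image = canny_detector(
    input_image,
    detect_resolution=int(generate_size * canny_detect),
    image_resolution=generate_size,
)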
@@ -99,7 +100,7 @@ def image_to_image(
         guidance_scale=guidance_scale,
         strength=strength,
         num_inference_steps=num_steps,
-        controlnet_conditioning_scale=[cond_scale1, cond_scale2],
+        controlnet_conditioning_scale=cond_scale,
     ).images[0]
 
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -137,15 +138,17 @@ def create_demo() -> gr.Blocks:
             with gr.Column():
                 num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
                 guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
-                strength = gr.Slider(minimum=0, maximum=3, value=0.2, step=0.1, label="Strength")
             with gr.Column():
+                strength = gr.Slider(minimum=0, maximum=3, value=0.2, step=0.1, label="Strength")
                 with gr.Accordion("Advanced Options", open=False):
                     mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
                     mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                     seed = gr.Number(label="Seed", value=8)
                     category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
-                    cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond_scale1")
-                    cond_scale2 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond_scale2")
+                    cond_scale1 = gr.Slider(minimum=0, maximum=3, value=0.8, step=0.1, label="Cond_scale1")
+                    cond_scale2 = gr.Slider(minimum=0, maximum=3, value=0.3, step=0.1, label="Cond_scale2")
+                    lineart_detect = gr.Slider(minimum=0, maximum=1, value=0.375, step=0.01, label="Lineart Detect")
+                    canny_detect = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="Canny Detect")
         g_btn = gr.Button("Edit Image")
 
         with gr.Row():
@@ -164,7 +167,7 @@ def create_demo() -> gr.Blocks:
         outputs=[origin_area_image, croper],
     ).success(
         fn=image_to_image,
-        inputs=[origin_area_image, edit_prompt, seed, num_steps, guidance_scale, strength, generate_size, cond_scale1, cond_scale2],
+        inputs=[origin_area_image, edit_prompt, seed, num_steps, guidance_scale, strength, generate_size, cond_scale1, cond_scale2, lineart_detect, canny_detect],
         outputs=[generated_image, generated_cost],
     ).success(
         fn=restore_result,
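
The handler chain above follows Gradio's .success() pattern: each stage runs only if the previous handler returned without raising, so a failed segmentation never reaches the img2img stage. A self-contained toy sketch of the same pattern (the component names here are illustrative, not from the app):

import gradio as gr

def step_a(x):
    return x + 1

def step_b(y):
    return y * 2

with gr.Blocks() as demo:
    n = gr.Number(value=1)
    out = gr.Number()
    btn = gr.Button("Run")
    # .success() fires only when the preceding handler finished without an
    # exception -- the same guard used for segment -> edit -> restore above.
    btn.click(fn=step_a, inputs=n, outputs=out).success(
        fn=step_b, inputs=out, outputs=out
    )

demo.launch()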
 