zhiweili commited on
Commit
02c0c4b
1 Parent(s): 2f53645

change app

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app_text2img.py +23 -4
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
 
3
- from app_haircolor import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
 
1
  import gradio as gr
2
 
3
+ from app_text2img import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
app_text2img.py CHANGED
@@ -18,6 +18,8 @@ from diffusers import (
18
  from controlnet_aux import (
19
  LineartDetector,
20
  CannyDetector,
 
 
21
  )
22
 
23
  BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -31,8 +33,16 @@ DEFAULT_CATEGORY = "hair"
31
  lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
32
  lineart_detector = lineart_detector.to(DEVICE)
33
 
 
 
 
34
  canndy_detector = CannyDetector()
35
 
 
 
 
 
 
36
  adapters = MultiAdapter(
37
  [
38
  T2IAdapter.from_pretrained(
@@ -45,6 +55,11 @@ adapters = MultiAdapter(
45
  torch_dtype=torch.float16,
46
  varient="fp16",
47
  ),
 
 
 
 
 
48
  ]
49
  )
50
  adapters = adapters.to(torch.float16)
@@ -71,6 +86,7 @@ def image_to_image(
71
  generate_size: int,
72
  lineart_scale: float = 1.0,
73
  canny_scale: float = 0.5,
 
74
  ):
75
  run_task_time = 0
76
  time_cost_str = ''
@@ -79,9 +95,11 @@ def image_to_image(
79
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
80
  canny_image = canndy_detector(input_image, 384, generate_size)
81
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
 
82
 
83
- cond_image = [lineart_image, canny_image]
84
- cond_scale = [lineart_scale, canny_scale]
85
 
86
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
87
  generated_image = basepipeline(
@@ -127,8 +145,9 @@ def create_demo() -> gr.Blocks:
127
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
128
  with gr.Column():
129
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
130
- lineart_scale = gr.Slider(minimum=0, maximum=2, value=0.3, step=0.1, label="Lineart Scale")
131
  canny_scale = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Canny Scale")
 
132
  g_btn = gr.Button("Edit Image")
133
 
134
  with gr.Row():
@@ -147,7 +166,7 @@ def create_demo() -> gr.Blocks:
147
  outputs=[origin_area_image, croper],
148
  ).success(
149
  fn=image_to_image,
150
- inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, generate_size, lineart_scale, canny_scale],
151
  outputs=[generated_image, generated_cost],
152
  ).success(
153
  fn=restore_result,
 
18
  from controlnet_aux import (
19
  LineartDetector,
20
  CannyDetector,
21
+ PidiNetDetector,
22
+ MidasDetector,
23
  )
24
 
25
  BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
 
33
  lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
34
  lineart_detector = lineart_detector.to(DEVICE)
35
 
36
+ pidinet_detector = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
37
+ pidinet_detector = pidinet_detector.to(DEVICE)
38
+
39
  canndy_detector = CannyDetector()
40
 
41
+ midas_detector = MidasDetector.from_pretrained(
42
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
43
+ )
44
+ midas_detector = midas_detector.to(DEVICE)
45
+
46
  adapters = MultiAdapter(
47
  [
48
  T2IAdapter.from_pretrained(
 
55
  torch_dtype=torch.float16,
56
  varient="fp16",
57
  ),
58
+ T2IAdapter.from_pretrained(
59
+ "TencentARC/t2i-adapter-sketch-sdxl-1.0",
60
+ torch_dtype=torch.float16,
61
+ varient="fp16",
62
+ ),
63
  ]
64
  )
65
  adapters = adapters.to(torch.float16)
 
86
  generate_size: int,
87
  lineart_scale: float = 1.0,
88
  canny_scale: float = 0.5,
89
+ sketch_scale: float = 1.0,
90
  ):
91
  run_task_time = 0
92
  time_cost_str = ''
 
95
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
96
  canny_image = canndy_detector(input_image, 384, generate_size)
97
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
98
+ sketch_image = pidinet_detector(input_image, 512, generate_size)
99
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
100
 
101
+ cond_image = [lineart_image, canny_image, sketch_image]
102
+ cond_scale = [lineart_scale, canny_scale, sketch_scale]
103
 
104
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
105
  generated_image = basepipeline(
 
145
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
146
  with gr.Column():
147
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
148
+ lineart_scale = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Lineart Scale")
149
  canny_scale = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Canny Scale")
150
+ sketch_scale = gr.Slider(minimum=0, maximum=2, value=1, step=0.1, label="Sketch Scale")
151
  g_btn = gr.Button("Edit Image")
152
 
153
  with gr.Row():
 
166
  outputs=[origin_area_image, croper],
167
  ).success(
168
  fn=image_to_image,
169
+ inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, generate_size, lineart_scale, canny_scale, sketch_scale],
170
  outputs=[generated_image, generated_cost],
171
  ).success(
172
  fn=restore_result,