zhiweili commited on
Commit
0a87234
1 Parent(s): 05649e7

change to inpaint 15

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app_haircolor_inpaint_15.py +8 -5
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
 
3
- from app_haircolor_img2img import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
 
1
  import gradio as gr
2
 
3
+ from app_haircolor_inpaint_15 import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
app_haircolor_inpaint_15.py CHANGED
@@ -30,7 +30,7 @@ BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-inpainting"
30
 
31
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
 
33
- DEFAULT_EDIT_PROMPT = "change hair to linen blonde"
34
  DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, poorly drawn face, bad face, fused face, ugly face, worst face, asymmetrical, unrealistic skin texture, bad proportions, out of frame, poorly drawn hands, cloned face, double face"
35
 
36
  DEFAULT_CATEGORY = "hair"
@@ -77,6 +77,7 @@ def image_to_image(
77
  seed: int,
78
  num_steps: int,
79
  guidance_scale: float,
 
80
  generate_size: int,
81
  cond_scale1: float = 1.0,
82
  cond_scale2: float = 0.6,
@@ -85,9 +86,9 @@ def image_to_image(
85
  time_cost_str = ''
86
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
87
  # canny_image = canny_detector(input_image, int(generate_size*1), generate_size)
88
- lineart_image = lineart_detector(input_image, 384, generate_size)
89
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
90
- pidiNet_image = pidiNet_detector(input_image, 512, generate_size)
91
  control_image = [lineart_image, pidiNet_image]
92
 
93
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
@@ -101,6 +102,7 @@ def image_to_image(
101
  height=generate_size,
102
  width=generate_size,
103
  guidance_scale=guidance_scale,
 
104
  num_inference_steps=num_steps,
105
  controlnet_conditioning_scale=[cond_scale1, cond_scale2],
106
  ).images[0]
@@ -136,11 +138,12 @@ def create_demo() -> gr.Blocks:
136
  with gr.Row():
137
  with gr.Column():
138
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
139
- generate_size = gr.Number(label="Generate Size", value=512)
140
  with gr.Column():
141
  num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
142
  guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
143
  with gr.Column():
 
144
  with gr.Accordion("Advanced Options", open=False):
145
  cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond Scale1")
146
  cond_scale2 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond Scale2")
@@ -167,7 +170,7 @@ def create_demo() -> gr.Blocks:
167
  outputs=[origin_area_image, mask_image, croper],
168
  ).success(
169
  fn=image_to_image,
170
- inputs=[origin_area_image, mask_image, edit_prompt,seed, num_steps, guidance_scale, generate_size, cond_scale1, cond_scale2],
171
  outputs=[generated_image, generated_cost],
172
  ).success(
173
  fn=restore_result,
 
30
 
31
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
 
33
+ DEFAULT_EDIT_PROMPT = "RAW photo, Fujifilm XT3, sharp hair, high resolution hair, hair tones, natural hair, magazine hair, white color hair"
34
  DEFAULT_NEGATIVE_PROMPT = "worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, poorly drawn face, bad face, fused face, ugly face, worst face, asymmetrical, unrealistic skin texture, bad proportions, out of frame, poorly drawn hands, cloned face, double face"
35
 
36
  DEFAULT_CATEGORY = "hair"
 
77
  seed: int,
78
  num_steps: int,
79
  guidance_scale: float,
80
+ strength: float,
81
  generate_size: int,
82
  cond_scale1: float = 1.0,
83
  cond_scale2: float = 0.6,
 
86
  time_cost_str = ''
87
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
88
  # canny_image = canny_detector(input_image, int(generate_size*1), generate_size)
89
+ lineart_image = lineart_detector(input_image, 768, generate_size)
90
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
91
+ pidiNet_image = pidiNet_detector(input_image, 768, generate_size)
92
  control_image = [lineart_image, pidiNet_image]
93
 
94
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
 
102
  height=generate_size,
103
  width=generate_size,
104
  guidance_scale=guidance_scale,
105
+ strength=strength,
106
  num_inference_steps=num_steps,
107
  controlnet_conditioning_scale=[cond_scale1, cond_scale2],
108
  ).images[0]
 
138
  with gr.Row():
139
  with gr.Column():
140
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
141
+ generate_size = gr.Number(label="Generate Size", value=768)
142
  with gr.Column():
143
  num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
144
  guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
145
  with gr.Column():
146
+ strength = gr.Slider(minimum=0, maximum=2, value=0.2, step=0.1, label="Strength")
147
  with gr.Accordion("Advanced Options", open=False):
148
  cond_scale1 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond Scale1")
149
  cond_scale2 = gr.Slider(minimum=0, maximum=3, value=1.2, step=0.1, label="Cond Scale2")
 
170
  outputs=[origin_area_image, mask_image, croper],
171
  ).success(
172
  fn=image_to_image,
173
+ inputs=[origin_area_image, mask_image, edit_prompt,seed, num_steps, guidance_scale, strength, generate_size, cond_scale1, cond_scale2],
174
  outputs=[generated_image, generated_cost],
175
  ).success(
176
  fn=restore_result,