zhiweili commited on
Commit
7386f72
1 Parent(s): ace9e93

add hair color app

Browse files
Files changed (4) hide show
  1. app.py +3 -0
  2. app_base.py +1 -9
  3. app_haircolor.py +142 -0
  4. inversion_run_base.py +8 -0
app.py CHANGED
@@ -1,10 +1,13 @@
1
  import gradio as gr
2
 
3
  from app_base import create_demo as create_demo_face
 
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
7
  with gr.Tab(label="Face"):
8
  create_demo_face()
 
 
9
 
10
  demo.launch()
 
1
  import gradio as gr
2
 
3
  from app_base import create_demo as create_demo_face
4
+ from app_haircolor import create_demo as create_demo_haircolor
5
 
6
  with gr.Blocks(css="style.css") as demo:
7
  with gr.Tabs():
8
  with gr.Tab(label="Face"):
9
  create_demo_face()
10
+ with gr.Tab(label="Hair Color"):
11
+ create_demo_haircolor()
12
 
13
  demo.launch()
app_base.py CHANGED
@@ -24,14 +24,6 @@ DEFAULT_CATEGORY = "face"
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
- os.system("pip freeze")
28
- if not os.path.exists('GFPGANv1.4.pth'):
29
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
30
- if not os.path.exists('realesr-general-x4v3.pth'):
31
- os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
32
-
33
- os.makedirs('output', exist_ok=True)
34
-
35
  def create_demo() -> gr.Blocks:
36
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
37
  model_path = 'realesr-general-x4v3.pth'
@@ -40,7 +32,7 @@ def create_demo() -> gr.Blocks:
40
 
41
  face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
42
 
43
- @spaces.GPU(duration=15)
44
  def image_to_image(
45
  input_image: Image,
46
  input_image_prompt: str,
 
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
 
 
 
 
 
 
 
 
27
  def create_demo() -> gr.Blocks:
28
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
29
  model_path = 'realesr-general-x4v3.pth'
 
32
 
33
  face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
34
 
35
+ @spaces.GPU(duration=10)
36
  def image_to_image(
37
  input_image: Image,
38
  input_image_prompt: str,
app_haircolor.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import time
4
+ import torch
5
+ import os
6
+ import numpy as np
7
+ import cv2
8
+
9
+ from PIL import Image
10
+ from inversion_run_base import run as base_run
11
+ from segment_utils import(
12
+ segment_image,
13
+ restore_result,
14
+ )
15
+ from gfpgan.utils import GFPGANer
16
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
17
+ from realesrgan.utils import RealESRGANer
18
+
19
+
20
+ DEFAULT_SRC_PROMPT = "a woman"
21
+ DEFAULT_EDIT_PROMPT = "a woman, with blue hair, 8k, high quality"
22
+
23
+ DEFAULT_CATEGORY = "hair"
24
+
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ def create_demo() -> gr.Blocks:
28
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
29
+ model_path = 'realesr-general-x4v3.pth'
30
+ half = True if torch.cuda.is_available() else False
31
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
32
+
33
+ @spaces.GPU(duration=10)
34
+ def image_to_image(
35
+ input_image: Image,
36
+ input_image_prompt: str,
37
+ edit_prompt: str,
38
+ seed: int,
39
+ w1: float,
40
+ num_steps: int,
41
+ start_step: int,
42
+ guidance_scale: float,
43
+ generate_size: int,
44
+ adapter_weights: float,
45
+ ):
46
+ w2 = 1.0
47
+ run_task_time = 0
48
+ time_cost_str = ''
49
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
50
+ run_model = base_run
51
+ res_image = run_model(
52
+ input_image,
53
+ input_image_prompt,
54
+ edit_prompt,
55
+ generate_size,
56
+ seed,
57
+ w1,
58
+ w2,
59
+ num_steps,
60
+ start_step,
61
+ guidance_scale,
62
+ adapter_weights,
63
+ )
64
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
65
+ enhanced_image = enhance(res_image)
66
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
67
+
68
+ return enhanced_image, res_image, time_cost_str
69
+
70
+ def get_time_cost(run_task_time, time_cost_str):
71
+ now_time = int(time.time()*1000)
72
+ if run_task_time == 0:
73
+ time_cost_str = 'start'
74
+ else:
75
+ if time_cost_str != '':
76
+ time_cost_str += f'-->'
77
+ time_cost_str += f'{now_time - run_task_time}'
78
+ run_task_time = now_time
79
+ return run_task_time, time_cost_str
80
+
81
+
82
+ def enhance(
83
+ pil_image: Image,
84
+ ):
85
+ img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
86
+
87
+ h, w = img.shape[0:2]
88
+ if h < 300:
89
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
90
+ output, _ = upsampler.enhance(img, outscale=2)
91
+ pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
92
+
93
+ return pil_output
94
+
95
+ with gr.Blocks() as demo:
96
+ croper = gr.State()
97
+ with gr.Row():
98
+ with gr.Column():
99
+ input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
100
+ edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
101
+ category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
102
+ with gr.Column():
103
+ num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
104
+ start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
105
+ with gr.Accordion("Advanced Options", open=False):
106
+ guidance_scale = gr.Slider(minimum=0, maximum=20, value=1, step=0.5, label="Guidance Scale")
107
+ generate_size = gr.Number(label="Generate Size", value=512)
108
+ mask_expansion = gr.Number(label="Mask Expansion", value=10, visible=True)
109
+ mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
110
+ adapter_weights = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Adapter Weights", visible=False)
111
+ with gr.Column():
112
+ seed = gr.Number(label="Seed", value=8)
113
+ w1 = gr.Number(label="W1", value=2)
114
+ g_btn = gr.Button("Edit Image")
115
+
116
+ with gr.Row():
117
+ with gr.Column():
118
+ input_image = gr.Image(label="Input Image", type="pil")
119
+ with gr.Column():
120
+ restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
121
+ download_path = gr.File(label="Download the output image", interactive=False)
122
+ with gr.Column():
123
+ origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
124
+ enhanced_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
125
+ generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
126
+ generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
127
+
128
+ g_btn.click(
129
+ fn=segment_image,
130
+ inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
131
+ outputs=[origin_area_image, croper],
132
+ ).success(
133
+ fn=image_to_image,
134
+ inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size, adapter_weights],
135
+ outputs=[enhanced_image, generated_image, generated_cost],
136
+ ).success(
137
+ fn=restore_result,
138
+ inputs=[croper, category, enhanced_image],
139
+ outputs=[restored_image, download_path],
140
+ )
141
+
142
+ return demo
inversion_run_base.py CHANGED
@@ -12,6 +12,14 @@ from config import get_config, get_num_steps_actual
12
  from functools import partial
13
  from compel import Compel, ReturnedEmbeddingsType
14
 
 
 
 
 
 
 
 
 
15
  class Object(object):
16
  pass
17
 
 
12
  from functools import partial
13
  from compel import Compel, ReturnedEmbeddingsType
14
 
15
+ os.system("pip freeze")
16
+ if not os.path.exists('GFPGANv1.4.pth'):
17
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
18
+ if not os.path.exists('realesr-general-x4v3.pth'):
19
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
20
+
21
+ os.makedirs('output', exist_ok=True)
22
+
23
  class Object(object):
24
  pass
25