zhiweili committed on
Commit 813fcc1
1 Parent(s): 4220acb

add enhance utils

Files changed (5)
  1. app.py +3 -3
  2. app_base.py +2 -32
  3. app_haircolor.py +4 -24
  4. enhance_utils.py +41 -0
  5. inversion_run_adapter.py +0 -9
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
 
-# from app_base import create_demo as create_demo_face
+from app_base import create_demo as create_demo_face
 from app_haircolor import create_demo as create_demo_haircolor
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
-        # with gr.Tab(label="Face"):
-        #     create_demo_face()
+        with gr.Tab(label="Face"):
+            create_demo_face()
         with gr.Tab(label="Hair Color"):
             create_demo_haircolor()
 
app_base.py CHANGED
@@ -2,19 +2,13 @@ import spaces
 import gradio as gr
 import time
 import torch
-import os
-import numpy as np
-import cv2
 
 from PIL import Image
 from segment_utils import(
     segment_image,
     restore_result,
 )
-from gfpgan.utils import GFPGANer
-from basicsr.archs.srvgg_arch import SRVGGNetCompact
-from realesrgan.utils import RealESRGANer
-
+from enhance_utils import enhance_image
 
 DEFAULT_SRC_PROMPT = "a woman, photo"
 DEFAULT_EDIT_PROMPT = "a beautiful woman, photo, hollywood style face, 8k, high quality"
@@ -25,12 +19,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def create_demo() -> gr.Blocks:
     from inversion_run_base import run as base_run
-    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
-    model_path = 'realesr-general-x4v3.pth'
-    half = True if torch.cuda.is_available() else False
-    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
-
-    face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
 
     @spaces.GPU(duration=10)
     def image_to_image(
@@ -65,7 +53,7 @@ def create_demo() -> gr.Blocks:
             adapter_weights,
         )
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-        enhanced_image = enhance(res_image, enhance_face)
+        enhanced_image = enhance_image(res_image, enhance_face)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
         return enhanced_image, res_image, time_cost_str
@@ -81,24 +69,6 @@ def create_demo() -> gr.Blocks:
         run_task_time = now_time
         return run_task_time, time_cost_str
 
-
-    def enhance(
-        pil_image: Image,
-        enhance_face: bool = True,
-    ):
-        img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
-
-        h, w = img.shape[0:2]
-        if h < 300:
-            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
-        if enhance_face:
-            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
-        else:
-            output, _ = upsampler.enhance(img, outscale=2)
-        pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
-
-        return pil_output
-
     with gr.Blocks() as demo:
         croper = gr.State()
         with gr.Row():
app_haircolor.py CHANGED
@@ -10,9 +10,7 @@ from segment_utils import(
     segment_image,
     restore_result,
 )
-from gfpgan.utils import GFPGANer
-from basicsr.archs.srvgg_arch import SRVGGNetCompact
-from realesrgan.utils import RealESRGANer
+from enhance_utils import enhance_image
 
 
 DEFAULT_SRC_PROMPT = "a woman"
@@ -23,11 +21,7 @@ DEFAULT_CATEGORY = "hair"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def create_demo() -> gr.Blocks:
-    from inversion_run_adapter import run as realvxl_run
-    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
-    model_path = 'realesr-general-x4v3.pth'
-    half = True if torch.cuda.is_available() else False
-    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+    from inversion_run_adapter import run as adapter_run
 
     @spaces.GPU(duration=10)
     def image_to_image(
@@ -48,7 +42,7 @@ def create_demo() -> gr.Blocks:
         run_task_time = 0
         time_cost_str = ''
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-        run_model = realvxl_run
+        run_model = adapter_run
         res_image = run_model(
             input_image,
             input_image_prompt,
@@ -65,7 +59,7 @@ def create_demo() -> gr.Blocks:
             sketch_scale,
         )
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-        enhanced_image = enhance(res_image)
+        enhanced_image = enhance_image(res_image, False)
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
         return enhanced_image, res_image, time_cost_str
@@ -81,20 +75,6 @@ def create_demo() -> gr.Blocks:
         run_task_time = now_time
         return run_task_time, time_cost_str
 
-
-    def enhance(
-        pil_image: Image,
-    ):
-        img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
-
-        h, w = img.shape[0:2]
-        if h < 300:
-            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
-        output, _ = upsampler.enhance(img, outscale=2)
-        pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
-
-        return pil_output
-
     with gr.Blocks() as demo:
         croper = gr.State()
         with gr.Row():
enhance_utils.py ADDED
@@ -0,0 +1,41 @@
+import os
+import torch
+import cv2
+import numpy as np
+
+from PIL import Image
+from gfpgan.utils import GFPGANer
+from basicsr.archs.srvgg_arch import SRVGGNetCompact
+from realesrgan.utils import RealESRGANer
+
+os.system("pip freeze")
+if not os.path.exists('GFPGANv1.4.pth'):
+    os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
+if not os.path.exists('realesr-general-x4v3.pth'):
+    os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
+
+os.makedirs('output', exist_ok=True)
+
+model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+model_path = 'realesr-general-x4v3.pth'
+half = True if torch.cuda.is_available() else False
+upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+
+face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2)
+
+def enhance_image(
+    pil_image: Image,
+    enhance_face: bool = True,
+):
+    img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+
+    h, w = img.shape[0:2]
+    if h < 300:
+        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+    if enhance_face:
+        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
+    else:
+        output, _ = upsampler.enhance(img, outscale=2)
+    pil_output = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
+
+    return pil_output
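
The new module loads both enhancers at import time and exposes a single enhance_image helper, which app_base.py and app_haircolor.py now call instead of their local copies. A minimal usage sketch (the input file path is hypothetical; it assumes the GFPGAN and Real-ESRGAN weights above have already been fetched by the module's import-time setup):

# Usage sketch of enhance_utils.enhance_image (hypothetical input path).
from PIL import Image
from enhance_utils import enhance_image

src = Image.open("input.png").convert("RGB")        # hypothetical source image

face_restored = enhance_image(src, True)            # GFPGAN face restoration (app_base.py path)
upscaled_only = enhance_image(src, False)           # Real-ESRGAN upscale only (app_haircolor.py path)

face_restored.save("output/enhanced.png")           # 'output' dir is created by enhance_utils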
inversion_run_adapter.py CHANGED
@@ -1,5 +1,4 @@
 import torch
-import os
 
 from diffusers import (
     DDPMScheduler,
@@ -20,14 +19,6 @@ from config import get_config, get_num_steps_actual
 from functools import partial
 from compel import Compel, ReturnedEmbeddingsType
 
-os.system("pip freeze")
-if not os.path.exists('GFPGANv1.4.pth'):
-    os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
-if not os.path.exists('realesr-general-x4v3.pth'):
-    os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
-
-os.makedirs('output', exist_ok=True)
-
 class Object(object):
     pass
 