leonelhs commited on
Commit
0f528f7
1 Parent(s): e49a8d7

add upsampler

Browse files
Files changed (1) hide show
  1. app.py +124 -36
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
 
3
  import gradio as gr
4
  import torch
 
5
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
6
  from gfpgan.utils import GFPGANer
7
  from huggingface_hub import hf_hub_download
@@ -12,46 +13,115 @@ GFPGAN_REPO_ID = 'leonelhs/gfpgan'
12
 
13
  os.system("pip freeze")
14
 
15
- # background enhancer with RealESRGAN
16
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
17
- model_path = hf_hub_download(repo_id=REALESRGAN_REPO_ID, filename='realesr-general-x4v3.pth')
18
- half = True if torch.cuda.is_available() else False
19
- upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
20
 
 
 
 
 
 
 
21
 
22
- def download_model(file):
 
23
  return hf_hub_download(repo_id=GFPGAN_REPO_ID, filename=file)
24
 
25
 
26
- def predict(image, version, scale):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  scale = int(scale)
28
- face_enhancer = None
29
-
30
- if version == 'v1.2':
31
- path = download_model('GFPGANv1.2.pth')
32
- face_enhancer = GFPGANer(
33
- model_path=path, upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
34
- elif version == 'v1.3':
35
- path = download_model('GFPGANv1.3.pth')
36
- face_enhancer = GFPGANer(
37
- model_path=path, upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
38
- elif version == 'v1.4':
39
- path = download_model('GFPGANv1.4.pth')
40
- face_enhancer = GFPGANer(
41
- model_path=path, upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
42
- elif version == 'RestoreFormer':
43
- path = download_model('RestoreFormer.pth')
44
- face_enhancer = GFPGANer(
45
- model_path=path, upscale=scale, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
46
-
47
- _, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
48
-
49
- return output
50
-
51
-
52
- title = "GFPGAN"
53
  description = r"""
54
- <b>Practical Face Restoration Algorithm</b>
55
  """
56
  article = r"""
57
  <center><span>[email protected] or [email protected]</span></center>
@@ -62,13 +132,31 @@ article = r"""
62
  demo = gr.Interface(
63
  predict, [
64
  gr.Image(type="numpy", label="Input"),
65
- gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  gr.Dropdown(["1", "2", "3", "4"], value="2", label="Rescaling factor")
67
  ], [
68
- gr.Image(type="numpy", label="Output", interactive=False)
 
69
  ],
70
  title=title,
71
  description=description,
72
  article=article)
73
 
74
- demo.queue().launch()
 
2
 
3
  import gradio as gr
4
  import torch
5
+ from basicsr.archs.rrdbnet_arch import RRDBNet
6
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
7
  from gfpgan.utils import GFPGANer
8
  from huggingface_hub import hf_hub_download
 
13
 
14
  os.system("pip freeze")
15
 
 
 
 
 
 
16
 
17
+ def showGPU():
18
+ if torch.cuda.is_available():
19
+ devices = torch.cuda.device_count()
20
+ current = torch.cuda.current_device()
21
+ return f"Running on GPU:{current} of {devices} total devices"
22
+ return "Running on CPU"
23
 
24
+
25
+ def download_model_gfpgan(file):
26
  return hf_hub_download(repo_id=GFPGAN_REPO_ID, filename=file)
27
 
28
 
29
+ def download_model_realesrgan(file):
30
+ return hf_hub_download(repo_id=REALESRGAN_REPO_ID, filename=file)
31
+
32
+
33
+ def select_upsampler(version, netscale=4):
34
+ model = None
35
+ dni_weight = None
36
+
37
+ version = version + ".pth"
38
+ model_path = download_model_realesrgan(version)
39
+
40
+ if version == 'RealESRGAN_x4plus.pth': # x4 RRDBNet model
41
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
42
+
43
+ if version == 'RealESRNet_x4plus.pth': # x4 RRDBNet model
44
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
45
+
46
+ if version == 'AI-Forever_x4plus.pth': # x4 RRDBNet model
47
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
48
+
49
+ if version == 'RealESRGAN_x4plus_anime_6B.pth': # x4 RRDBNet model with 6 blocks
50
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
51
+
52
+ if version == 'RealESRGAN_x2plus.pth': # x2 RRDBNet model
53
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
54
+ netscale = 2 # This is
55
+
56
+ if version == 'AI-Forever_x2plus.pth': # x2 RRDBNet model
57
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
58
+ netscale = 2 # This is
59
+
60
+ if version == 'realesr-animevideov3.pth': # x4 VGG-style model (XS size)
61
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
62
+
63
+ if version == 'realesr-general-x4v3.pth': # x4 VGG-style model (S size)
64
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
65
+ model_path = [
66
+ download_model_realesrgan("realesr-general-x4v3.pth"),
67
+ download_model_realesrgan("realesr-general-wdn-x4v3.pth")
68
+ ]
69
+ dni_weight = [0.2, 0.8]
70
+
71
+ half = True if torch.cuda.is_available() else False
72
+
73
+ return RealESRGANer(
74
+ scale=netscale,
75
+ model_path=model_path,
76
+ dni_weight=dni_weight,
77
+ model=model,
78
+ tile=0,
79
+ tile_pad=10,
80
+ pre_pad=0,
81
+ half=half,
82
+ gpu_id=0)
83
+
84
+
85
+ def select_face_enhancer(version, scale, upsampler):
86
+ if 'v1.2' in version:
87
+ model_path = download_model_gfpgan('GFPGANv1.2.pth')
88
+ return GFPGANer(
89
+ model_path=model_path, upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
90
+ elif 'v1.3' in version:
91
+ model_path = download_model_gfpgan('GFPGANv1.3.pth')
92
+ return GFPGANer(
93
+ model_path=model_path, upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
94
+ elif 'v1.4' in version:
95
+ model_path = download_model_gfpgan('GFPGANv1.4.pth')
96
+ return GFPGANer(
97
+ model_path=model_path, upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
98
+ elif 'RestoreFormer' in version:
99
+ model_path = download_model_gfpgan('RestoreFormer.pth')
100
+ return GFPGANer(
101
+ model_path=model_path, upscale=scale, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
102
+
103
+
104
+ def predict(image, version_upsampler, version_enhancer, scale):
105
  scale = int(scale)
106
+
107
+ upsampler = select_upsampler(version_upsampler)
108
+
109
+ if "No additional" not in version_enhancer:
110
+ face_enhancer = select_face_enhancer(version_enhancer, scale, upsampler)
111
+ _, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
112
+ else:
113
+ output, _ = upsampler.enhance(image, outscale=scale)
114
+
115
+ log = f"General enhance version: {version_upsampler}\n " \
116
+ f"Face enhance version: {version_enhancer} \n " \
117
+ f"Scale:{scale} \n {showGPU()}"
118
+
119
+ return output, log
120
+
121
+
122
+ title = "Super Face"
 
 
 
 
 
 
 
 
123
  description = r"""
124
+ <b>Practical Image Restoration Algorithm based on Real-ESRGAN, GFPGAN</b>
125
  """
126
  article = r"""
127
  <center><span>[email protected] or [email protected]</span></center>
 
132
  demo = gr.Interface(
133
  predict, [
134
  gr.Image(type="numpy", label="Input"),
135
+ gr.Dropdown([
136
+ 'RealESRGAN_x2plus',
137
+ 'RealESRGAN_x4plus',
138
+ 'RealESRNet_x4plus',
139
+ 'AI-Forever_x2plus',
140
+ 'AI-Forever_x4plus',
141
+ 'RealESRGAN_x4plus_anime_6B',
142
+ 'realesr-animevideov3',
143
+ 'realesr-general-x4v3'],
144
+ type="value", value='RealESRGAN_x4plus', label='General restoration algorithm', info="version"),
145
+ gr.Dropdown([
146
+ 'No additional face process',
147
+ 'GFPGANv1.2',
148
+ 'GFPGANv1.3',
149
+ 'GFPGANv1.4',
150
+ 'RestoreFormer'],
151
+ type="value", value='No additional face process', label='Special face restoration algorithm',
152
+ info="version"),
153
  gr.Dropdown(["1", "2", "3", "4"], value="2", label="Rescaling factor")
154
  ], [
155
+ gr.Image(type="numpy", label="Output", interactive=False),
156
+ gr.Textbox(label="log info")
157
  ],
158
  title=title,
159
  description=description,
160
  article=article)
161
 
162
+ demo.queue().launch(share=True, debug=True)