hysts HF staff commited on
Commit
31cdb53
1 Parent(s): 53ddf51

Apply superresolution with Real-ESRGAN

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. model.py +19 -3
app.py CHANGED
@@ -60,6 +60,8 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
60
  step=1,
61
  value=1234,
62
  label='Seed')
 
 
63
  run_button = gr.Button('Run')
64
  with gr.Column():
65
  with gr.Tabs():
@@ -80,6 +82,7 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
80
  num_steps,
81
  randomize_seed,
82
  seed,
 
83
  ],
84
  outputs=[
85
  result,
 
60
  step=1,
61
  value=1234,
62
  label='Seed')
63
+ superresolve = gr.Checkbox(value=False,
64
+ label='Superresolve')
65
  run_button = gr.Button('Run')
66
  with gr.Column():
67
  with gr.Tabs():
 
82
  num_steps,
83
  randomize_seed,
84
  seed,
85
+ superresolve,
86
  ],
87
  outputs=[
88
  result,
model.py CHANGED
@@ -6,6 +6,7 @@ import random
6
  import sys
7
  import tempfile
8
 
 
9
  import imageio
10
  import numpy as np
11
  import PIL.Image
@@ -44,6 +45,8 @@ class Model:
44
  self.scheduler_type)
45
  self.rng = random.Random()
46
 
 
 
47
  @staticmethod
48
  def _load_pipeline(model_name: str,
49
  scheduler_type: str) -> DiffusionPipeline:
@@ -140,17 +143,29 @@ class Model:
140
  writer.close()
141
 
142
  logger.info('--- done ---')
143
- return res, out_file.name
 
 
 
 
 
 
 
 
 
 
144
 
145
  def run(self, model_name: str, scheduler_type: str, num_steps: int,
146
- randomize_seed: bool,
147
- seed: int) -> tuple[PIL.Image.Image, int, str]:
148
  self.set_pipeline(model_name, scheduler_type)
149
  if scheduler_type == 'PNDM':
150
  num_steps = max(4, min(num_steps, 100))
151
  if randomize_seed:
152
  seed = self.rng.randint(0, 100000)
153
  res, filename = self.generate_with_video(seed, num_steps)
 
 
154
  return res, seed, filename
155
 
156
  @staticmethod
@@ -169,4 +184,5 @@ class Model:
169
  self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
170
  seed = self.rng.randint(0, 1000000)
171
  images = self.generate(seed, num_steps=10, num_images=4)
 
172
  return self.to_grid(images, 2)
 
6
  import sys
7
  import tempfile
8
 
9
+ import gradio as gr
10
  import imageio
11
  import numpy as np
12
  import PIL.Image
 
45
  self.scheduler_type)
46
  self.rng = random.Random()
47
 
48
+ self.real_esrgan = gr.Interface.load('spaces/hysts/Real-ESRGAN-anime')
49
+
50
  @staticmethod
51
  def _load_pipeline(model_name: str,
52
  scheduler_type: str) -> DiffusionPipeline:
 
143
  writer.close()
144
 
145
  logger.info('--- done ---')
146
+ return PIL.Image.fromarray(res), out_file.name
147
+
148
+ def superresolve(self, image: PIL.Image.Image) -> PIL.Image.Image:
149
+ logger.info('--- superresolve ---')
150
+
151
+ with tempfile.NamedTemporaryFile(suffix='.png') as f:
152
+ image.save(f.name)
153
+ out_file = self.real_esrgan(f.name)
154
+
155
+ logger.info('--- done ---')
156
+ return PIL.Image.open(out_file)
157
 
158
  def run(self, model_name: str, scheduler_type: str, num_steps: int,
159
+ randomize_seed: bool, seed: int,
160
+ superresolve: bool) -> tuple[PIL.Image.Image, int, str]:
161
  self.set_pipeline(model_name, scheduler_type)
162
  if scheduler_type == 'PNDM':
163
  num_steps = max(4, min(num_steps, 100))
164
  if randomize_seed:
165
  seed = self.rng.randint(0, 100000)
166
  res, filename = self.generate_with_video(seed, num_steps)
167
+ if superresolve:
168
+ res = self.superresolve(res)
169
  return res, seed, filename
170
 
171
  @staticmethod
 
184
  self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
185
  seed = self.rng.randint(0, 1000000)
186
  images = self.generate(seed, num_steps=10, num_images=4)
187
+ images = [self.superresolve(image) for image in images]
188
  return self.to_grid(images, 2)