hysts HF staff commited on
Commit
cb229bd
1 Parent(s): ea424ac

Add simple mode

Browse files
Files changed (3) hide show
  1. app.py +12 -3
  2. model.py +31 -6
  3. style.css +4 -0
app.py CHANGED
@@ -59,10 +59,16 @@ def main():
59
 
60
  with gr.Blocks(css='style.css') as demo:
61
  gr.Markdown(TITLE)
62
- gr.Markdown(DESCRIPTION)
63
 
64
  with gr.Tabs():
65
- with gr.TabItem('App'):
 
 
 
 
 
 
 
66
  with gr.Row():
67
  with gr.Column():
68
  with gr.Group():
@@ -87,7 +93,7 @@ def main():
87
  label='Seed')
88
  run_button = gr.Button('Run')
89
  with gr.Column():
90
- result = gr.Image(label='Result', elem_id='result')
91
 
92
  with gr.TabItem('Sample Images'):
93
  with gr.Row():
@@ -119,6 +125,9 @@ def main():
119
  scheduler_type,
120
  ],
121
  outputs=None)
 
 
 
122
  run_button.click(fn=model.run,
123
  inputs=[
124
  model_name,
 
59
 
60
  with gr.Blocks(css='style.css') as demo:
61
  gr.Markdown(TITLE)
 
62
 
63
  with gr.Tabs():
64
+ with gr.TabItem('Simple Mode'):
65
+ run_button_simple = gr.Button('Generate')
66
+ result_simple = gr.Image(show_label=False,
67
+ elem_id='result-grid')
68
+
69
+ with gr.TabItem('Advanced Mode'):
70
+ gr.Markdown(DESCRIPTION)
71
+
72
  with gr.Row():
73
  with gr.Column():
74
  with gr.Group():
 
93
  label='Seed')
94
  run_button = gr.Button('Run')
95
  with gr.Column():
96
+ result = gr.Image(show_label=False, elem_id='result')
97
 
98
  with gr.TabItem('Sample Images'):
99
  with gr.Row():
 
125
  scheduler_type,
126
  ],
127
  outputs=None)
128
+ run_button_simple.click(fn=model.run_simple,
129
+ inputs=None,
130
+ outputs=result_simple)
131
  run_button.click(fn=model.run,
132
  inputs=[
133
  model_name,
model.py CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
2
 
3
  import logging
4
  import os
 
5
  import sys
6
 
 
7
  import PIL.Image
8
  import torch
9
  from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
@@ -37,6 +39,7 @@ class Model:
37
  self.scheduler_type = 'DDIM'
38
  self.pipeline = self._load_pipeline(self.model_name,
39
  self.scheduler_type)
 
40
 
41
  def _load_pipeline(self, model_name: str,
42
  scheduler_type: str) -> DiffusionPipeline:
@@ -78,18 +81,21 @@ class Model:
78
  for name in self.MODEL_NAMES:
79
  self._load_pipeline(name, 'DDPM')
80
 
81
- def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
 
 
 
82
  logger.info('--- generate ---')
83
  logger.info(f'{seed=}, {num_steps=}')
84
 
85
  torch.manual_seed(seed)
86
  if self.scheduler_type == 'DDPM':
87
- res = self.pipeline(batch_size=1,
88
- torch_device=self.device)['sample'][0]
89
  elif self.scheduler_type in ['DDIM', 'PNDM']:
90
- res = self.pipeline(batch_size=1,
91
  torch_device=self.device,
92
- num_inference_steps=num_steps)['sample'][0]
93
  else:
94
  raise ValueError
95
 
@@ -106,4 +112,23 @@ class Model:
106
  self.set_pipeline(model_name, scheduler_type)
107
  if scheduler_type == 'PNDM':
108
  num_steps = max(4, min(num_steps, 100))
109
- return self.generate(seed, num_steps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import logging
4
  import os
5
+ import random
6
  import sys
7
 
8
+ import numpy as np
9
  import PIL.Image
10
  import torch
11
  from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
 
39
  self.scheduler_type = 'DDIM'
40
  self.pipeline = self._load_pipeline(self.model_name,
41
  self.scheduler_type)
42
+ self.rng = random.Random()
43
 
44
  def _load_pipeline(self, model_name: str,
45
  scheduler_type: str) -> DiffusionPipeline:
 
81
  for name in self.MODEL_NAMES:
82
  self._load_pipeline(name, 'DDPM')
83
 
84
+ def generate(self,
85
+ seed: int,
86
+ num_steps: int,
87
+ num_images: int = 1) -> list[PIL.Image.Image]:
88
  logger.info('--- generate ---')
89
  logger.info(f'{seed=}, {num_steps=}')
90
 
91
  torch.manual_seed(seed)
92
  if self.scheduler_type == 'DDPM':
93
+ res = self.pipeline(batch_size=num_images,
94
+ torch_device=self.device)['sample']
95
  elif self.scheduler_type in ['DDIM', 'PNDM']:
96
+ res = self.pipeline(batch_size=num_images,
97
  torch_device=self.device,
98
+ num_inference_steps=num_steps)['sample']
99
  else:
100
  raise ValueError
101
 
 
112
  self.set_pipeline(model_name, scheduler_type)
113
  if scheduler_type == 'PNDM':
114
  num_steps = max(4, min(num_steps, 100))
115
+ return self.generate(seed, num_steps)[0]
116
+
117
+ @staticmethod
118
+ def to_grid(images: list[PIL.Image.Image],
119
+ ncols: int = 2) -> PIL.Image.Image:
120
+ images = [np.asarray(image) for image in images]
121
+ nrows = (len(images) + ncols - 1) // ncols
122
+ h, w = images[0].shape[:2]
123
+ d = nrows * ncols - len(images)
124
+ if d > 0:
125
+ images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d
126
+ grid = np.asarray(images).reshape(2, 2, h, w, 3).transpose(
127
+ 0, 2, 1, 3, 4).reshape(2 * h, 2 * w, 3)
128
+ return PIL.Image.fromarray(grid)
129
+
130
+ def run_simple(self) -> PIL.Image.Image:
131
+ self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
132
+ seed = self.rng.randint(0, 100000)
133
+ images = self.generate(seed, num_steps=10, num_images=4)
134
+ return self.to_grid(images, 2)
style.css CHANGED
@@ -1,6 +1,10 @@
1
  h1 {
2
  text-align: center;
3
  }
 
 
 
 
4
  div#result {
5
  max-width: 400px;
6
  max-height: 400px;
 
1
  h1 {
2
  text-align: center;
3
  }
4
+ div#result-grid {
5
+ max-width: 600px;
6
+ max-height: 600px;
7
+ }
8
  div#result {
9
  max-width: 400px;
10
  max-height: 400px;