Spaces:
Add simple mode
app.py
CHANGED
@@ -59,10 +59,16 @@ def main():
 
     with gr.Blocks(css='style.css') as demo:
         gr.Markdown(TITLE)
-        gr.Markdown(DESCRIPTION)
 
         with gr.Tabs():
-            with gr.TabItem('
+            with gr.TabItem('Simple Mode'):
+                run_button_simple = gr.Button('Generate')
+                result_simple = gr.Image(show_label=False,
+                                         elem_id='result-grid')
+
+            with gr.TabItem('Advanced Mode'):
+                gr.Markdown(DESCRIPTION)
+
                 with gr.Row():
                     with gr.Column():
                         with gr.Group():
@@ -87,7 +93,7 @@ def main():
                                              label='Seed')
                         run_button = gr.Button('Run')
                     with gr.Column():
-                        result = gr.Image(
+                        result = gr.Image(show_label=False, elem_id='result')
 
             with gr.TabItem('Sample Images'):
                 with gr.Row():
@@ -119,6 +125,9 @@ def main():
                              scheduler_type,
                          ],
                          outputs=None)
+        run_button_simple.click(fn=model.run_simple,
+                                inputs=None,
+                                outputs=result_simple)
         run_button.click(fn=model.run,
                          inputs=[
                              model_name,
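For orientation, the Simple Mode wiring added above follows the usual Gradio pattern of a click handler with no inputs and a single Image output: clicking Generate calls model.run_simple() with no arguments and renders the returned image in result_simple. A minimal, self-contained sketch of that pattern, with a placeholder make_grid callback standing in for model.run_simple:

import gradio as gr
import numpy as np


def make_grid():
    # Placeholder callback standing in for model.run_simple: returns a
    # blank white image instead of sampling from a diffusion model.
    return np.full((256, 256, 3), 255, dtype=np.uint8)


with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem('Simple Mode'):
            run_button_simple = gr.Button('Generate')
            result_simple = gr.Image(show_label=False)

    # inputs=None means the callback is invoked with no arguments; the
    # returned array (or PIL image) is shown by the Image component.
    run_button_simple.click(fn=make_grid,
                            inputs=None,
                            outputs=result_simple)

demo.launch()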
model.py
CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
 
 import logging
 import os
+import random
 import sys
 
+import numpy as np
 import PIL.Image
 import torch
 from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
@@ -37,6 +39,7 @@ class Model:
         self.scheduler_type = 'DDIM'
         self.pipeline = self._load_pipeline(self.model_name,
                                             self.scheduler_type)
+        self.rng = random.Random()
 
     def _load_pipeline(self, model_name: str,
                        scheduler_type: str) -> DiffusionPipeline:
@@ -78,18 +81,21 @@ class Model:
         for name in self.MODEL_NAMES:
             self._load_pipeline(name, 'DDPM')
 
-    def generate(self,
+    def generate(self,
+                 seed: int,
+                 num_steps: int,
+                 num_images: int = 1) -> list[PIL.Image.Image]:
         logger.info('--- generate ---')
         logger.info(f'{seed=}, {num_steps=}')
 
         torch.manual_seed(seed)
         if self.scheduler_type == 'DDPM':
-            res = self.pipeline(batch_size=
-                                torch_device=self.device)['sample']
+            res = self.pipeline(batch_size=num_images,
+                                torch_device=self.device)['sample']
         elif self.scheduler_type in ['DDIM', 'PNDM']:
-            res = self.pipeline(batch_size=
+            res = self.pipeline(batch_size=num_images,
                                 torch_device=self.device,
-                                num_inference_steps=num_steps)['sample']
+                                num_inference_steps=num_steps)['sample']
         else:
             raise ValueError
 
@@ -106,4 +112,23 @@ class Model:
         self.set_pipeline(model_name, scheduler_type)
         if scheduler_type == 'PNDM':
             num_steps = max(4, min(num_steps, 100))
-        return self.generate(seed, num_steps)
+        return self.generate(seed, num_steps)[0]
+
+    @staticmethod
+    def to_grid(images: list[PIL.Image.Image],
+                ncols: int = 2) -> PIL.Image.Image:
+        images = [np.asarray(image) for image in images]
+        nrows = (len(images) + ncols - 1) // ncols
+        h, w = images[0].shape[:2]
+        d = nrows * ncols - len(images)
+        if d > 0:
+            images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d
+        grid = np.asarray(images).reshape(2, 2, h, w, 3).transpose(
+            0, 2, 1, 3, 4).reshape(2 * h, 2 * w, 3)
+        return PIL.Image.fromarray(grid)
+
+    def run_simple(self) -> PIL.Image.Image:
+        self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
+        seed = self.rng.randint(0, 100000)
+        images = self.generate(seed, num_steps=10, num_images=4)
+        return self.to_grid(images, 2)
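Note that to_grid as committed computes nrows and ncols but then reshapes with hard-coded 2x2 dimensions, so it only lays out exactly four images; that is fine here because run_simple always requests four. A short sketch of the same routine honoring the computed grid shape (an illustrative variant, not part of the commit):

from __future__ import annotations

import numpy as np
import PIL.Image


def to_grid(images: list[PIL.Image.Image], ncols: int = 2) -> PIL.Image.Image:
    arrays = [np.asarray(image) for image in images]
    nrows = (len(arrays) + ncols - 1) // ncols
    h, w = arrays[0].shape[:2]
    # Pad with white tiles so the grid is completely filled.
    pad = nrows * ncols - len(arrays)
    if pad > 0:
        arrays += [np.full((h, w, 3), 255, dtype=np.uint8)] * pad
    # Use the computed nrows x ncols shape instead of a fixed 2 x 2.
    grid = np.asarray(arrays).reshape(nrows, ncols, h, w, 3).transpose(
        0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3)
    return PIL.Image.fromarray(grid)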
style.css
CHANGED
@@ -1,6 +1,10 @@
 h1 {
   text-align: center;
 }
+div#result-grid {
+  max-width: 600px;
+  max-height: 600px;
+}
 div#result {
   max-width: 400px;
   max-height: 400px;
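The new div#result-grid rule sizes the Simple Mode output: elem_id='result-grid' on the gr.Image in app.py becomes the HTML id of that component, so the stylesheet loaded via gr.Blocks(css='style.css') can target it with an id selector, just as div#result already targets the Advanced Mode image.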