fffiloni KingNish commited on
Commit
8ef457d
1 Parent(s): 4a91cdc

Added Option to select where to align the base image. (#7)

Browse files

- Added Option to select where to align the base image. (07baece5cb7768ed8cec71fed442f785bdda15c4)


Co-authored-by: Nishith Jain <[email protected]>

Files changed (1) hide show
  1. app.py +128 -57
app.py CHANGED
@@ -12,10 +12,6 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
15
- MODELS = {
16
- "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
17
- }
18
-
19
  config_file = hf_hub_download(
20
  "xinsir/controlnet-union-sdxl-1.0",
21
  filename="config_promax.json",
@@ -48,11 +44,20 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
48
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
49
 
50
 
 
 
 
 
 
 
 
 
 
51
  @spaces.GPU
52
- def infer(image, model_selection, width, height, overlap_width, num_inference_steps, prompt_input=None):
 
53
  source = image
54
  target_size = (width, height)
55
- target_ratio = (width, height) # Calculate aspect ratio from width and height
56
  overlap = overlap_width
57
 
58
  # Upscale if source is smaller than target in both dimensions
@@ -68,25 +73,63 @@ def infer(image, model_selection, width, height, overlap_width, num_inference_st
68
  new_height = int(source.height * scale_factor)
69
  source = source.resize((new_width, new_height), Image.LANCZOS)
70
 
71
- margin_x = (target_size[0] - source.width) // 2
72
- margin_y = (target_size[1] - source.height) // 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  background = Image.new('RGB', target_size, (255, 255, 255))
75
  background.paste(source, (margin_x, margin_y))
76
 
77
  mask = Image.new('L', target_size, 255)
78
  mask_draw = ImageDraw.Draw(mask)
79
- mask_draw.rectangle([
80
- (margin_x + overlap, margin_y + overlap),
81
- (margin_x + source.width - overlap, margin_y + source.height - overlap)
82
- ], fill=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  cnet_image = background.copy()
85
  cnet_image.paste(0, (0, 0), mask)
86
 
87
- final_prompt = "high quality"
88
- if prompt_input.strip() != "":
89
- final_prompt += ", " + prompt_input
90
 
91
  (
92
  prompt_embeds,
@@ -110,7 +153,14 @@ def infer(image, model_selection, width, height, overlap_width, num_inference_st
110
 
111
  yield background, cnet_image
112
 
 
 
 
 
 
 
113
  def preload_presets(target_ratio):
 
114
  if target_ratio == "9:16":
115
  changed_width = 720
116
  changed_height = 1280
@@ -122,9 +172,6 @@ def preload_presets(target_ratio):
122
  elif target_ratio == "Custom":
123
  return 720, 1280, gr.update(open=True)
124
 
125
- def clear_result():
126
- return gr.update(value=None)
127
-
128
 
129
  css = """
130
  .gradio-container {
@@ -152,63 +199,64 @@ with gr.Blocks(css=css) as demo:
152
  with gr.Column():
153
  input_image = gr.Image(
154
  type="pil",
155
- label="Input Image",
156
- sources=["upload"],
157
- height = 300
158
  )
159
-
160
- prompt_input = gr.Textbox(label="Prompt (Optional)")
161
-
 
 
 
 
162
  with gr.Row():
163
  target_ratio = gr.Radio(
164
- label = "Expected Ratio",
165
- choices = ["9:16", "16:9", "Custom"],
166
- value = "9:16",
167
- scale = 2
168
  )
169
 
170
- run_button = gr.Button("Generate", scale=1)
 
 
 
 
171
 
172
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
173
- with gr.Column():
174
  with gr.Row():
175
  width_slider = gr.Slider(
176
  label="Width",
177
  minimum=720,
178
- maximum=1440,
179
  step=8,
180
  value=720, # Set a default value
181
  )
182
  height_slider = gr.Slider(
183
  label="Height",
184
  minimum=720,
185
- maximum=1440,
186
  step=8,
187
  value=1280, # Set a default value
188
  )
189
  with gr.Row():
190
- model_selection = gr.Dropdown(
191
- choices=list(MODELS.keys()),
192
- value="RealVisXL V5.0 Lightning",
193
- label="Model",
 
 
 
194
  )
195
- num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8 )
196
-
197
- overlap_width = gr.Slider(
198
- label="Mask overlap width",
199
- minimum=1,
200
- maximum=50,
201
- value=42,
202
- step=1
203
- )
204
-
205
  gr.Examples(
206
  examples=[
207
- ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", 1280, 720],
208
- ["./examples/example_2.jpg", "RealVisXL V5.0 Lightning", 720, 1280],
209
- ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024],
 
210
  ],
211
- inputs=[input_image, model_selection, width_slider, height_slider],
212
  )
213
 
214
  with gr.Column():
@@ -216,21 +264,38 @@ with gr.Blocks(css=css) as demo:
216
  interactive=False,
217
  label="Generated Image",
218
  )
 
 
 
 
 
219
 
 
 
 
 
 
 
220
  target_ratio.change(
221
- fn = preload_presets,
222
- inputs = [target_ratio],
223
- outputs = [width_slider, height_slider, settings_panel],
224
- queue = False
225
  )
 
226
  run_button.click(
227
  fn=clear_result,
228
  inputs=None,
229
  outputs=result,
230
  ).then(
231
  fn=infer,
232
- inputs=[input_image, model_selection, width_slider, height_slider, overlap_width, num_inference_steps, prompt_input],
 
233
  outputs=result,
 
 
 
 
234
  )
235
 
236
  prompt_input.submit(
@@ -239,8 +304,14 @@ with gr.Blocks(css=css) as demo:
239
  outputs=result,
240
  ).then(
241
  fn=infer,
242
- inputs=[input_image, model_selection, width_slider, height_slider, overlap_width, num_inference_steps, prompt_input],
 
243
  outputs=result,
 
 
 
 
244
  )
245
 
246
- demo.queue(max_size=12).launch(share=False, show_error=True, show_api=False)
 
 
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
 
 
 
 
15
  config_file = hf_hub_download(
16
  "xinsir/controlnet-union-sdxl-1.0",
17
  filename="config_promax.json",
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
 
47
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
48
+ """Checks if the image can be expanded based on the alignment."""
49
+ if alignment in ("Left", "Right") and source_width >= target_width:
50
+ return False
51
+ if alignment in ("Top", "Bottom") and source_height >= target_height:
52
+ return False
53
+ return True
54
+
55
+
56
  @spaces.GPU
57
+ def infer(image, width, height, overlap_width, num_inference_steps, prompt_input=None, alignment="Middle"):
58
+
59
  source = image
60
  target_size = (width, height)
 
61
  overlap = overlap_width
62
 
63
  # Upscale if source is smaller than target in both dimensions
 
73
  new_height = int(source.height * scale_factor)
74
  source = source.resize((new_width, new_height), Image.LANCZOS)
75
 
76
+ if not can_expand(source.width, source.height, target_size[0], target_size[1], alignment):
77
+ alignment = "Middle"
78
+
79
+ # Calculate margins based on alignment
80
+ if alignment == "Middle":
81
+ margin_x = (target_size[0] - source.width) // 2
82
+ margin_y = (target_size[1] - source.height) // 2
83
+ elif alignment == "Left":
84
+ margin_x = 0
85
+ margin_y = (target_size[1] - source.height) // 2
86
+ elif alignment == "Right":
87
+ margin_x = target_size[0] - source.width
88
+ margin_y = (target_size[1] - source.height) // 2
89
+ elif alignment == "Top":
90
+ margin_x = (target_size[0] - source.width) // 2
91
+ margin_y = 0
92
+ elif alignment == "Bottom":
93
+ margin_x = (target_size[0] - source.width) // 2
94
+ margin_y = target_size[1] - source.height
95
 
96
  background = Image.new('RGB', target_size, (255, 255, 255))
97
  background.paste(source, (margin_x, margin_y))
98
 
99
  mask = Image.new('L', target_size, 255)
100
  mask_draw = ImageDraw.Draw(mask)
101
+
102
+ # Adjust mask generation based on alignment
103
+ if alignment == "Middle":
104
+ mask_draw.rectangle([
105
+ (margin_x + overlap, margin_y + overlap),
106
+ (margin_x + source.width - overlap, margin_y + source.height - overlap)
107
+ ], fill=0)
108
+ elif alignment == "Left":
109
+ mask_draw.rectangle([
110
+ (margin_x, margin_y),
111
+ (margin_x + source.width - overlap, margin_y + source.height)
112
+ ], fill=0)
113
+ elif alignment == "Right":
114
+ mask_draw.rectangle([
115
+ (margin_x + overlap, margin_y),
116
+ (margin_x + source.width, margin_y + source.height)
117
+ ], fill=0)
118
+ elif alignment == "Top":
119
+ mask_draw.rectangle([
120
+ (margin_x, margin_y),
121
+ (margin_x + source.width, margin_y + source.height - overlap)
122
+ ], fill=0)
123
+ elif alignment == "Bottom":
124
+ mask_draw.rectangle([
125
+ (margin_x, margin_y + overlap),
126
+ (margin_x + source.width, margin_y + source.height)
127
+ ], fill=0)
128
 
129
  cnet_image = background.copy()
130
  cnet_image.paste(0, (0, 0), mask)
131
 
132
+ final_prompt = f"{prompt_input} , high quality, 4k"
 
 
133
 
134
  (
135
  prompt_embeds,
 
153
 
154
  yield background, cnet_image
155
 
156
+
157
+ def clear_result():
158
+ """Clears the result ImageSlider."""
159
+ return gr.update(value=None)
160
+
161
+
162
  def preload_presets(target_ratio):
163
+ """Updates the width and height sliders based on the selected aspect ratio."""
164
  if target_ratio == "9:16":
165
  changed_width = 720
166
  changed_height = 1280
 
172
  elif target_ratio == "Custom":
173
  return 720, 1280, gr.update(open=True)
174
 
 
 
 
175
 
176
  css = """
177
  .gradio-container {
 
199
  with gr.Column():
200
  input_image = gr.Image(
201
  type="pil",
202
+ label="Input Image"
 
 
203
  )
204
+
205
+ with gr.Row():
206
+ with gr.Column(scale=2):
207
+ prompt_input = gr.Textbox(label="Prompt (Optional)")
208
+ with gr.Column(scale=1):
209
+ run_button = gr.Button("Generate")
210
+
211
  with gr.Row():
212
  target_ratio = gr.Radio(
213
+ label="Expected Ratio",
214
+ choices=["9:16", "16:9", "Custom"],
215
+ value="9:16",
216
+ scale=2
217
  )
218
 
219
+ alignment_dropdown = gr.Dropdown(
220
+ choices=["Middle", "Left", "Right", "Top", "Bottom"],
221
+ value="Middle",
222
+ label="Alignment"
223
+ )
224
 
225
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
226
+ with gr.Column():
227
  with gr.Row():
228
  width_slider = gr.Slider(
229
  label="Width",
230
  minimum=720,
231
+ maximum=1536,
232
  step=8,
233
  value=720, # Set a default value
234
  )
235
  height_slider = gr.Slider(
236
  label="Height",
237
  minimum=720,
238
+ maximum=1536,
239
  step=8,
240
  value=1280, # Set a default value
241
  )
242
  with gr.Row():
243
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
244
+ overlap_width = gr.Slider(
245
+ label="Mask overlap width",
246
+ minimum=1,
247
+ maximum=50,
248
+ value=42,
249
+ step=1
250
  )
251
+
 
 
 
 
 
 
 
 
 
252
  gr.Examples(
253
  examples=[
254
+ ["./examples/example_1.webp", 1280, 720, "Middle"],
255
+ ["./examples/example_2.jpg", 1440, 810, "Left"],
256
+ ["./examples/example_3.jpg", 1024, 1024, "Top"],
257
+ ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
258
  ],
259
+ inputs=[input_image, width_slider, height_slider, alignment_dropdown],
260
  )
261
 
262
  with gr.Column():
 
264
  interactive=False,
265
  label="Generated Image",
266
  )
267
+ use_as_input_button = gr.Button("Use as Input Image", visible=False)
268
+
269
+ def use_output_as_input(output_image):
270
+ """Sets the generated output as the new input image."""
271
+ return gr.update(value=output_image[1])
272
 
273
+ use_as_input_button.click(
274
+ fn=use_output_as_input,
275
+ inputs=[result],
276
+ outputs=[input_image]
277
+ )
278
+
279
  target_ratio.change(
280
+ fn=preload_presets,
281
+ inputs=[target_ratio],
282
+ outputs=[width_slider, height_slider, settings_panel],
283
+ queue=False
284
  )
285
+
286
  run_button.click(
287
  fn=clear_result,
288
  inputs=None,
289
  outputs=result,
290
  ).then(
291
  fn=infer,
292
+ inputs=[input_image, width_slider, height_slider, overlap_width, num_inference_steps,
293
+ prompt_input, alignment_dropdown],
294
  outputs=result,
295
+ ).then(
296
+ fn=lambda: gr.update(visible=True),
297
+ inputs=None,
298
+ outputs=use_as_input_button,
299
  )
300
 
301
  prompt_input.submit(
 
304
  outputs=result,
305
  ).then(
306
  fn=infer,
307
+ inputs=[input_image, width_slider, height_slider, overlap_width, num_inference_steps,
308
+ prompt_input, alignment_dropdown],
309
  outputs=result,
310
+ ).then(
311
+ fn=lambda: gr.update(visible=True),
312
+ inputs=None,
313
+ outputs=use_as_input_button,
314
  )
315
 
316
+
317
+ demo.queue(max_size=12).launch(share=False)