Pusheen committed on
Commit
003f3f9
1 Parent(s): b5b5e7b

Upload 8 files

Files changed (7)
  1. .gitattributes +0 -1
  2. .gitignore +112 -0
  3. README.md +4 -4
  4. __init__.py +0 -0
  5. app.py +590 -477
  6. environment.yaml +29 -0
  7. requirements.txt +15 -11
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,112 @@
1
+ # IntelliJ project files
2
+ .idea
3
+ *.iml
4
+ out
5
+ gen
6
+
7
+ ### Vim template
8
+ [._]*.s[a-w][a-z]
9
+ [._]s[a-w][a-z]
10
+ *.un~
11
+ Session.vim
12
+ .netrwhist
13
+ *~
14
+
15
+ ### IPythonNotebook template
16
+ # Temporary data
17
+ .ipynb_checkpoints/
18
+
19
+ ### Python template
20
+ # Byte-compiled / optimized / DLL files
21
+ __pycache__/
22
+ *.py[cod]
23
+ *$py.class
24
+
25
+ # C extensions
26
+ *.so
27
+
28
+ # Distribution / packaging
29
+ .Python
30
+ env/
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ #lib/
38
+ #lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ *.egg-info/
43
+ .installed.cfg
44
+ *.egg
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .coverage
60
+ .coverage.*
61
+ .cache
62
+ nosetests.xml
63
+ coverage.xml
64
+ *,cover
65
+
66
+ # Translations
67
+ *.mo
68
+ *.pot
69
+
70
+ # Django stuff:
71
+ *.log
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ target/
78
+
79
+ *.ipynb
80
+ *.params
81
+ # *.json
82
+ .vscode/
83
+ *.code-workspace/
84
+
85
+ lib/pycocotools/_mask.c
86
+ lib/nms/cpu_nms.c
87
+
88
+ OUTPUT
89
+ OUTPUT/*
90
+ models/*
91
+ DATASET
92
+ DATASET/*
93
+ external/
94
+ MODELS
95
+ MODELS/*
96
+ gradio_cached_examples/*
97
+
98
+ kill.sh
99
+
100
+ draws/
101
+ #:wq
102
+ #plot/figs
103
+
104
+ *venv/*
105
+
106
+ # images
107
+ # images/*
108
+
109
+ create_samples/
110
+ create_samples/*
111
+
112
+ ckpts/*
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: LoCo
3
- emoji: 🐠
4
- colorFrom: red
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: LoCo_Gligen Demo
3
+ emoji: 👁
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
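The block between the two `---` markers above is the Hugging Face Spaces front matter. A quick way to inspect the values this commit sets, assuming PyYAML is available locally (it is not pinned in this repo):

```python
import yaml  # PyYAML; an assumption, not listed in requirements.txt

# Split on the '---' markers and parse the front matter shown in the diff above
front_matter = open("README.md", encoding="utf-8").read().split("---")[1]
config = yaml.safe_load(front_matter)
print(config["title"], config["sdk"], config["sdk_version"])  # LoCo_Gligen Demo gradio 3.19.1
```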
__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,37 +1,164 @@
1
-
2
  import gradio as gr
3
  import torch
4
- from transformers import CLIPTextModel, CLIPTokenizer
5
- from diffusers import AutoencoderKL, DDIMScheduler
6
- from my_model import unet_2d_condition
7
  import json
8
  import numpy as np
9
  from PIL import Image, ImageDraw, ImageFont
10
  from functools import partial
 
11
  import math
12
- from utils import compute_loco_v2
 
13
  from gradio import processing_utils
14
  from typing import Optional
15
- from typing import List
16
 
17
  import warnings
18
- import string
19
 
20
- import sys
 
 
 
21
 
 
22
  sys.tracebacklimit = 0
23
 
24
  class Blocks(gr.Blocks):
25
 
26
  def __init__(
27
- self,
28
- theme: str = "default",
29
- analytics_enabled: Optional[bool] = None,
30
- mode: str = "blocks",
31
- title: str = "Gradio",
32
- css: Optional[str] = None,
33
- **kwargs,
34
  ):
 
35
  self.extra_configs = {
36
  'thumbnail': kwargs.pop('thumbnail', ''),
37
  'url': kwargs.pop('url', 'https://gradio.app/'),
@@ -46,9 +173,82 @@ class Blocks(gr.Blocks):
46
 
47
  for k, v in self.extra_configs.items():
48
  config[k] = v
49
-
50
  return config
51
-
 
52
  def draw_box(boxes=[], texts=[], img=None):
53
  if len(boxes) == 0 and img is None:
54
  return None
@@ -58,111 +258,13 @@ def draw_box(boxes=[], texts=[], img=None):
58
  colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
59
  draw = ImageDraw.Draw(img)
60
  font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
61
- print(boxes)
62
  for bid, box in enumerate(boxes):
63
  draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
64
  anno_text = texts[bid]
65
- draw.rectangle(
66
- [box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
67
- outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
68
- draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)], anno_text, font=font,
69
- fill=(255, 255, 255))
70
  return img
71
 
72
- '''
73
- inference model
74
- '''
75
-
76
- def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
77
- uncond_input = tokenizer(
78
- ["lowres, bad anatomy, bad hands, bad faces, text, error, missing fingers, extra digit, fewer digits, \
79
- cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
80
- )
81
-
82
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
83
-
84
- input_ids = tokenizer(
85
- prompt,
86
- padding="max_length",
87
- truncation=True,
88
- max_length=tokenizer.model_max_length,
89
- return_tensors="pt",
90
- ).input_ids[0].unsqueeze(0).to(device)
91
- # text_embeddings = text_encoder(input_ids)[0]
92
- text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
93
- # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
94
- generator = torch.manual_seed(rand_seed) # Seed generator to create the inital latent noise
95
-
96
- latents = torch.randn(
97
- (batch_size, 4, 64, 64),
98
- generator=generator,
99
- ).to(device)
100
-
101
- # noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
102
- noise_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
103
-
104
- # generator = torch.Generator("cuda").manual_seed(1024)
105
- noise_scheduler.set_timesteps(50)
106
-
107
- latents = latents * noise_scheduler.init_noise_sigma
108
-
109
- loss = torch.tensor(10000)
110
-
111
- for index, t in enumerate(noise_scheduler.timesteps):
112
- iteration = 0
113
-
114
- while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
115
- latents = latents.requires_grad_(True)
116
-
117
- # latent_model_input = torch.cat([latents] * 2)
118
- latent_model_input = latents
119
-
120
- latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
121
- noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
122
- unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
123
-
124
- # update latents with guidence from gaussian blob
125
-
126
- loss = compute_loco_v2(attn_map_integrated_down, attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
127
- object_positions=object_positions) * loss_scale
128
-
129
- # print(loss.item() / loss_scale)
130
-
131
- grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
132
-
133
- latents = latents - grad_cond
134
- iteration += 1
135
- torch.cuda.empty_cache()
136
- torch.cuda.empty_cache()
137
-
138
-
139
- with torch.no_grad():
140
-
141
- latent_model_input = torch.cat([latents] * 2)
142
-
143
- latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
144
- noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
145
- unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
146
-
147
- noise_pred = noise_pred.sample
148
-
149
- # perform classifier-free guidance
150
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
151
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
152
-
153
- latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
154
- torch.cuda.empty_cache()
155
- # Decode image
156
- with torch.no_grad():
157
- # print("decode image")
158
- latents = 1 / 0.18215 * latents
159
- image = vae.decode(latents).sample
160
- image = (image / 2 + 0.5).clamp(0, 1)
161
- image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
162
- images = (image * 255).round().astype("uint8")
163
- pil_images = [Image.fromarray(image) for image in images]
164
- return pil_images
165
-
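The deleted inference() above follows the loss-guided denoising pattern: at early timesteps it back-propagates an attention-map layout loss (compute_loco_v2) into the latents before the regular scheduler step. A minimal, self-contained sketch of just that update, with a toy placeholder loss standing in for the attention-map loss:

```python
import torch

def layout_loss(latents, boxes):
    # Toy stand-in for the attention-map loss: penalize latent energy that
    # falls outside every normalized (x0, y0, x1, y1) box.
    _, _, H, W = latents.shape
    inside = torch.zeros(H, W, dtype=torch.bool)
    for x0, y0, x1, y1 in boxes:
        inside[int(y0 * H):int(y1 * H), int(x0 * W):int(x1 * W)] = True
    return latents[..., ~inside].pow(2).mean()

def guided_update(latents, boxes, loss_scale=30.0, loss_threshold=0.1, max_iter=5):
    # Mirrors the while-loop above: repeat the gradient correction on the
    # latents until the (unscaled) loss drops below the threshold.
    for _ in range(max_iter):
        latents = latents.detach().requires_grad_(True)
        loss = layout_loss(latents, boxes) * loss_scale
        if loss.item() / loss_scale < loss_threshold:
            break
        grad = torch.autograd.grad(loss, [latents])[0]
        latents = latents - grad
    return latents.detach()

latents = torch.randn(1, 4, 64, 64)
latents = guided_update(latents, boxes=[(0.25, 0.25, 0.75, 0.75)])
print(latents.shape)  # torch.Size([1, 4, 64, 64])
```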
166
  def get_concat(ims):
167
  if len(ims) == 1:
168
  n_col = 1
@@ -177,94 +279,22 @@ def get_concat(ims):
177
  return dst
178
 
179
 
180
- def click_on_display(language_instruction, grounding_texts, sketch_pad,
181
- loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
182
- state):
183
- if 'boxes' not in state:
184
- state['boxes'] = []
185
- boxes = state['boxes']
186
- x = Image.open('./images/dog.png')
187
- gen_images = [gr.Image.update(value=x, visible=True)]
188
 
189
- return gen_images + [state]
190
 
191
- def Pharse2idx(prompt, phrases):
192
- phrases = [x.strip() for x in phrases.split(';')]
193
- print('phrases', phrases)
194
-
195
- punc_string = string.punctuation
196
- # for punc in [',', '.', ';', ':', '?', '!']:
197
- for punc in punc_string:
198
- prompt = prompt.replace(punc, ' ')
199
- print('clear pp:', prompt)
200
- prompt_list = prompt.strip('.').replace(',', '').split(' ')
201
-
202
- print('prompt_list', prompt_list)
203
- object_positions = []
204
- for obj in phrases:
205
- obj_position = []
206
- for word in obj.split(' '):
207
- print('word', word)
208
- obj_first_index = prompt_list.index(word) + 1
209
- obj_position.append(obj_first_index)
210
- object_positions.append(obj_position)
211
- print('object_positions', object_positions)
212
- return object_positions
213
-
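The deleted Pharse2idx() above maps each semicolon-separated phrase to 1-based word positions in the prompt via list.index, which raises if a phrase word never appears in the prompt. A self-contained, illustrative variant (not the project's function) that skips missing words instead:

```python
import string

def phrase_to_indices(prompt: str, phrases: str):
    # Strip punctuation, then map each semicolon-separated phrase to 1-based
    # word positions in the prompt, skipping words that never occur.
    words = prompt.lower().translate(str.maketrans("", "", string.punctuation)).split()
    positions = []
    for phrase in phrases.lower().split(";"):
        positions.append([words.index(w) + 1 for w in phrase.split() if w in words])
    return positions

print(phrase_to_indices("An airplane and a chair on the grassland.", "airplane;chair"))
# -> [[2], [5]]
```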
214
-
215
- def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
216
- loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
217
- state):
218
- # language_inst: prompt; grounding_texts: phrases
219
- if 'boxes' not in state:
220
- state['boxes'] = []
221
- boxes = state['boxes']
222
 
223
- # print('raw grounding texts:', grounding_texts)
224
- language_instruction= language_instruction.lower()
225
- phrases = grounding_texts.lower()
226
- # print('got phrases!')
227
- # grounding_texts = [x.strip() for x in grounding_texts.split(';')]
228
- # print('new grd texts:',grounding_texts)
229
-
230
- # # assert len(boxes) == len(grounding_texts)
231
- # if len(boxes) != len(grounding_texts):
232
- # if len(boxes) < len(grounding_texts):
233
- # raise ValueError("""The number of boxes should be equal to the number of grounding objects.
234
- # Number of boxes drawn: {}, number of grounding tokens: {}.
235
- # Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
236
- # grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
237
 
238
- boxes = (np.asarray(boxes) / 512).tolist()
239
- boxes = [[box] for box in boxes]
240
- # grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
241
- # language_instruction_list = language_instruction.strip('.').split(' ')
242
- # object_positions = []
243
- # for obj in grounding_texts:
244
- # obj_position = []
245
- # for word in obj.split(' '):
246
- # obj_first_index = language_instruction_list.index(word) + 1
247
- # obj_position.append(obj_first_index)
248
- # object_positions.append(obj_position)
249
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
250
-
251
- print('getting obj positions!')
252
- object_positions = Pharse2idx(language_instruction, phrases)
253
-
254
- gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed, guidance_scale)
255
-
256
- blank_samples = batch_size % 2 if batch_size > 1 else 0
257
- gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
258
- + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
259
- + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
260
-
261
- return gen_images + [state]
262
-
263
- def generate_legacy(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
264
- loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
265
  state):
266
  if 'boxes' not in state:
267
  state['boxes'] = []
 
268
  boxes = state['boxes']
269
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
270
  # assert len(boxes) == len(grounding_texts)
@@ -276,24 +306,49 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun
276
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
277
 
278
  boxes = (np.asarray(boxes) / 512).tolist()
279
- boxes = [[box] for box in boxes]
280
- grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
281
- language_instruction_list = language_instruction.strip('.').split(' ')
282
- object_positions = []
283
- for obj in grounding_texts:
284
- obj_position = []
285
- for word in obj.split(' '):
286
- obj_first_index = language_instruction_list.index(word) + 1
287
- obj_position.append(obj_first_index)
288
- object_positions.append(obj_position)
289
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
290
-
291
- gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed, guidance_scale)
 
292
 
293
  blank_samples = batch_size % 2 if batch_size > 1 else 0
294
- gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
295
- + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
296
- + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
297
 
298
  return gen_images + [state]
299
 
@@ -301,32 +356,28 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun
301
  def binarize(x):
302
  return (x != 0).astype('uint8') * 255
303
 
304
-
305
  def sized_center_crop(img, cropx, cropy):
306
  y, x = img.shape[:2]
307
  startx = x // 2 - (cropx // 2)
308
- starty = y // 2 - (cropy // 2)
309
- return img[starty:starty + cropy, startx:startx + cropx]
310
-
311
 
312
  def sized_center_fill(img, fill, cropx, cropy):
313
  y, x = img.shape[:2]
314
  startx = x // 2 - (cropx // 2)
315
- starty = y // 2 - (cropy // 2)
316
- img[starty:starty + cropy, startx:startx + cropx] = fill
317
  return img
318
 
319
-
320
  def sized_center_mask(img, cropx, cropy):
321
  y, x = img.shape[:2]
322
  startx = x // 2 - (cropx // 2)
323
- starty = y // 2 - (cropy // 2)
324
- center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
325
  img = (img * 0.2).astype('uint8')
326
- img[starty:starty + cropy, startx:startx + cropx] = center_region
327
  return img
328
 
329
-
330
  def center_crop(img, HW=None, tgt_size=(512, 512)):
331
  if HW is None:
332
  H, W = img.shape[:2]
@@ -336,27 +387,56 @@ def center_crop(img, HW=None, tgt_size=(512, 512)):
336
  img = img.resize(tgt_size)
337
  return np.array(img)
338
 
339
-
340
- def draw(input, grounding_texts, new_image_trigger, state):
341
  if type(input) == dict:
342
  image = input['image']
343
  mask = input['mask']
344
  else:
345
  mask = input
 
346
  if mask.ndim == 3:
347
- mask = 255 - mask[..., 0]
348
 
349
  image_scale = 1.0
350
 
351
  mask = binarize(mask)
352
 
353
  if type(mask) != np.ndarray:
354
  mask = np.array(mask)
355
 
356
- if mask.sum() == 0:
357
  state = {}
358
 
359
- image = None
 
 
 
360
 
361
  if 'boxes' not in state:
362
  state['boxes'] = []
@@ -385,277 +465,310 @@ def draw(input, grounding_texts, new_image_trigger, state):
385
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
386
  grounding_texts = [x for x in grounding_texts if len(x) > 0]
387
  if len(grounding_texts) < len(state['boxes']):
388
- grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]
 
389
  box_image = draw_box(state['boxes'], grounding_texts, image)
390
 
391
- return [box_image, new_image_trigger, image_scale, state]
 
 
 
 
392
 
 
393
 
394
  def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
395
  if task != 'Grounded Inpainting':
396
  sketch_pad_trigger = sketch_pad_trigger + 1
397
  blank_samples = batch_size % 2 if batch_size > 1 else 0
398
- out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
399
- # state = {}
400
- return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
401
-
402
-
403
- def main():
404
-
405
- css = """
406
-
407
- #component-0 {
408
- max-width: 550px;
409
- margin: auto;
410
- padding-top: 1.5rem;
411
- }
412
- #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
413
- {
414
- height: var(--height) !important;
415
- max-height: var(--height) !important;
416
- min-height: var(--height) !important;
417
- }
418
- #paper-info a {
419
- color:#008AD7;
420
- text-decoration: none;
421
- }
422
- #paper-info a:hover {
423
- cursor: pointer;
424
- text-decoration: none;
425
- }
426
- .container {
427
- max-width: 550px;
428
- margin: auto;
429
- padding-top: 1.5rem;
430
- }
431
- .tooltip {
432
- color: #555;
433
- position: relative;
434
- display: inline-block;
435
- cursor: pointer;
436
- }
437
-
438
- .tooltip .tooltiptext {
439
- visibility: hidden;
440
- width: 400px;
441
- background-color: #555;
442
- color: #fff;
443
- text-align: center;
444
- padding: 5px;
445
- border-radius: 5px;
446
- position: absolute;
447
- z-index: 1; /* Set z-index to 1 */
448
- left: 10px;
449
- top: 100%;
450
- opacity: 0;
451
- transition: opacity 0.3s;
452
- }
453
-
454
- .tooltip:hover .tooltiptext {
455
- visibility: visible;
456
- opacity: 1;
457
- z-index: 9999; /* Set a high z-index value when hovering */
458
- }
459
-
460
-
461
  """
 
462
 
463
- rescale_js = """
464
- function(x) {
465
- const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
466
- let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
467
- const image_width = root.querySelector('#img2img_image').clientWidth;
468
- const target_height = parseInt(image_width * image_scale);
469
- document.body.style.setProperty('--height', `${target_height}px`);
470
- root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
471
- root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
472
- return x;
473
- }
474
- """
475
- with open('./conf/unet/config.json') as f:
476
- unet_config = json.load(f)
477
-
478
- sd_path = "runwayml/stable-diffusion-v1-5"
479
- unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained(sd_path,
480
- subfolder="unet")
481
- tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
482
- text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
483
- vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
484
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
485
- unet.to(device)
486
- text_encoder.to(device)
487
- vae.to(device)
488
-
489
- with Blocks(
490
- css=css,
491
- analytics_enabled=False,
492
- title="LoCo: Locally Constrained Training-free Layout-to-Image Generation",
493
- ) as demo:
494
- description = """<p style="text-align: center; font-weight: bold;">
495
- <span style="font-size: 28px">LoCo: Locally Constrained Training-free Layout-to-Image Generation</span>
496
- <br>
497
- <span style="font-size: 18px" id="paper-info">
498
- [<a href="https://peiang-zhao.tech/LoCo/" target="_blank">Project Page</a>]
499
- [<a href="https://arxiv.org/pdf/2311.12342" target="_blank">Paper</a>]
500
- [<a href=" " target="_blank">GitHub</a>]
501
- </span>
502
- <p>Tips:
503
- <ul>
504
- <li>You can change the 'random seed' in 'Advanced Options' below to generate various images. </li>
505
- <li>Layouts with many small bounding boxes may lead to unpleasant results. It's a tough setting for training free methods like LoCo. </li>
506
- <li>Generate an image on A10G takes ~25 seconds. Upgrade the space's GPU for faster inference. :P </li>
507
- </ul>
508
- </p>
509
- """
510
- gr.HTML(description)
511
- with gr.Column():
512
- language_instruction = gr.Textbox(
513
- label="Text Prompt (e.g., a dog and a car)",
514
- )
515
- grounding_instruction = gr.Textbox(
516
- label="Grounding instruction (Separated by semicolon, e.g., dog;car)",
517
- )
518
  sketch_pad_trigger = gr.Number(value=0, visible=False)
519
  sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
520
  init_white_trigger = gr.Number(value=0, visible=False)
521
  image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
522
  new_image_trigger = gr.Number(value=0, visible=False)
523
 
524
-
525
- with gr.Row():
526
- sketch_pad = gr.Paint(label="Sketch Pad", elem_id="img2img_image", source='canvas', shape=(512, 512))
 
 
 
 
 
 
 
 
 
527
  with gr.Row():
528
- # sketch_pad = gr.Image(source='canvas', tool='sketch', size=(512, 512))
529
  out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
530
- with gr.Row():
531
- out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
532
-
533
  with gr.Row():
534
  clear_btn = gr.Button(value='Clear')
535
  gen_btn = gr.Button(value='Generate')
536
-
537
  with gr.Accordion("Advanced Options", open=False):
538
  with gr.Column():
539
- description = """<div class="tooltip">Loss Scale Factor &#9432
540
- <span class="tooltiptext">The scale factor of the constraints. The larger it is, the better control we get while it sometimes losses fidelity. </span>
541
- </div>
542
- <div class="tooltip">CFG Guidance Scale &#9432
543
- <span class="tooltiptext">The scale factor of classifier-free guidance. </span>
544
- </div>
545
- <div class="tooltip" >Max Iteration per Step &#9432
546
- <span class="tooltiptext">The max iterations of applying constraints in each diffusion inference process.</span>
547
- </div>
548
- <div class="tooltip" >Loss Threshold &#9432
549
- <span class="tooltiptext">The threshold of loss. If the loss computed by cross-attention map is smaller then the threshold, the guidance is stopped. </span>
550
- </div>
551
- <div class="tooltip" >Max Step of Backward Guidance &#9432
552
- <span class="tooltiptext">The max steps of guidance in diffusion inference process.</span>
553
- </div>
554
- """
555
- gr.HTML(description)
556
- Loss_scale = gr.Slider(minimum=0, maximum=200, step=5, value=50,label="Loss Scale Factor")
557
- guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="CFG Guidance Scale")
558
- batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
559
- max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
560
- loss_threshold = gr.Slider(minimum=0, maximum=0.2, step=0.001, value=0.002, label="Loss Threshold")
561
- max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Guidance")
562
- rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="Random Seed")
563
-
564
- state = gr.State({})
565
-
566
-
567
- class Controller:
568
- def __init__(self):
569
- self.calls = 0
570
- self.tracks = 0
571
- self.resizes = 0
572
- self.scales = 0
573
-
574
- def init_white(self, init_white_trigger):
575
- self.calls += 1
576
- return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
577
-
578
- def change_n_samples(self, n_samples):
579
- blank_samples = n_samples % 2 if n_samples > 1 else 0
580
- return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
581
- + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
582
-
583
-
584
- controller = Controller()
585
- demo.load(
586
- lambda x: x + 1,
587
- inputs=sketch_pad_trigger,
588
- outputs=sketch_pad_trigger,
589
- queue=False)
590
- sketch_pad.edit(
591
- draw,
592
- inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
593
- outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
594
- queue=False,
595
- )
596
- grounding_instruction.change(
597
- draw,
598
- inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
599
- outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
600
- queue=False,
601
- )
602
- clear_btn.click(
603
- clear,
604
- inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
605
- outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
606
- queue=False)
607
-
608
- sketch_pad_trigger.change(
609
- controller.init_white,
610
- inputs=[init_white_trigger],
611
- outputs=[sketch_pad, image_scale, init_white_trigger],
612
- queue=False)
613
-
614
- gen_btn.click(
615
- fn=partial(generate, unet, vae, tokenizer, text_encoder,),
616
- inputs=[
617
- language_instruction, grounding_instruction, sketch_pad,
618
- loss_threshold, guidance_scale, batch_size, rand_seed,
619
- max_step,
620
- Loss_scale, max_iter,
621
- state,
 
622
  ],
623
- outputs=[out_gen_1, state],
624
- queue=True
625
- )
626
- sketch_pad_resize_trigger.change(
627
- None,
628
- None,
629
- sketch_pad_resize_trigger,
630
- _js=rescale_js,
631
- queue=False)
632
- init_white_trigger.change(
633
- None,
634
- None,
635
- init_white_trigger,
636
- _js=rescale_js,
637
- queue=False)
638
-
639
- with gr.Column():
640
- gr.Examples(
641
- examples=[
642
- [
643
- # "images/input.png",
644
- "An airplane and a chair on the grassland.",
645
- "airplane;chair",
646
- "images/airplane_chair.png"
647
- ],
648
  ],
649
- inputs=[language_instruction, grounding_instruction, out_gen_1],
650
- outputs=None,
651
- fn=None,
652
- cache_examples=False,
653
- )
654
- description = """<p> Some source codes of the demo are modified based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GlIGen</a> and <a href="https://huggingface.co/spaces/silentchen/layout-guidance">Layout-guidance</a>. Thanks! </p>"""
655
- gr.HTML(description)
656
 
657
- demo.queue(concurrency_count=1, api_open=False)
658
- demo.launch(share=False, show_api=False, show_error=True)
659
 
660
- if __name__ == '__main__':
661
- main()
 
 
1
  import gradio as gr
2
  import torch
3
+ from omegaconf import OmegaConf
4
+ from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
5
+
6
  import json
7
  import numpy as np
8
  from PIL import Image, ImageDraw, ImageFont
9
  from functools import partial
10
+ from collections import Counter
11
  import math
12
+ import gc
13
+
14
  from gradio import processing_utils
15
  from typing import Optional
 
16
 
17
  import warnings
 
18
 
19
+ from datetime import datetime
20
+
21
+ from huggingface_hub import hf_hub_download
22
+ hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
23
 
24
+ import sys
25
  sys.tracebacklimit = 0
26
 
27
+
28
+ def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
29
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
30
+ return torch.load(cache_file, map_location='cpu')
31
+
32
+ def load_ckpt_config_from_hf(modality):
33
+ ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
34
+ config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
35
+ return ckpt, config
36
+
37
+
38
+ def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
39
+ pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
40
+ config = OmegaConf.create( config["_content"] ) # config used in training
41
+ config.alpha_scale = 1.0
42
+ config.model['params']['is_inpaint'] = is_inpaint
43
+ config.model['params']['is_style'] = is_style
44
+
45
+ if common_instances is None:
46
+ common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
47
+ common_instances = load_common_ckpt(config, common_ckpt)
48
+
49
+ loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
50
+
51
+ return loaded_model_list, common_instances
52
+
53
+
54
+ class Instance:
55
+ def __init__(self, capacity = 2):
56
+ self.model_type = 'base'
57
+ self.loaded_model_list = {}
58
+ self.counter = Counter()
59
+ self.global_counter = Counter()
60
+ self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
61
+ 'gligen-generation-text-box',
62
+ is_inpaint=False, is_style=False, common_instances=None
63
+ )
64
+ self.capacity = capacity
65
+
66
+ def _log(self, model_type, batch_size, instruction, phrase_list):
67
+ self.counter[model_type] += 1
68
+ self.global_counter[model_type] += 1
69
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
70
+ print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
71
+ current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
72
+ ))
73
+
74
+ def get_model(self, model_type, batch_size, instruction, phrase_list):
75
+ if model_type in self.loaded_model_list:
76
+ self._log(model_type, batch_size, instruction, phrase_list)
77
+ return self.loaded_model_list[model_type]
78
+
79
+ if self.capacity == len(self.loaded_model_list):
80
+ least_used_type = self.counter.most_common()[-1][0]
81
+ del self.loaded_model_list[least_used_type]
82
+ del self.counter[least_used_type]
83
+ gc.collect()
84
+ torch.cuda.empty_cache()
85
+
86
+ self.loaded_model_list[model_type] = self._get_model(model_type)
87
+ self._log(model_type, batch_size, instruction, phrase_list)
88
+ return self.loaded_model_list[model_type]
89
+
90
+ def _get_model(self, model_type):
91
+ if model_type == 'base':
92
+ return ckpt_load_helper(
93
+ 'gligen-generation-text-box',
94
+ is_inpaint=False, is_style=False, common_instances=self.common_instances
95
+ )[0]
96
+ elif model_type == 'inpaint':
97
+ return ckpt_load_helper(
98
+ 'gligen-inpainting-text-box',
99
+ is_inpaint=True, is_style=False, common_instances=self.common_instances
100
+ )[0]
101
+ elif model_type == 'style':
102
+ return ckpt_load_helper(
103
+ 'gligen-generation-text-image-box',
104
+ is_inpaint=False, is_style=True, common_instances=self.common_instances
105
+ )[0]
106
+
107
+ assert False
108
+
109
+ instance = Instance()
110
+
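Instance.get_model above keeps at most `capacity` checkpoints resident and evicts the least-requested one before loading another. A self-contained sketch of that counter-based eviction scheme, with a stand-in `load_fn` used in place of ckpt_load_helper (an assumption for illustration):

```python
from collections import Counter

class TinyModelCache:
    """Sketch of the eviction scheme in Instance.get_model above: keep at most
    `capacity` loaded entries and drop the least-requested one when full.
    `load_fn` stands in for ckpt_load_helper and is an assumption here."""

    def __init__(self, load_fn, capacity=2):
        self.load_fn = load_fn
        self.capacity = capacity
        self.cache = {}
        self.counter = Counter()

    def get(self, key):
        if key not in self.cache:
            if len(self.cache) >= self.capacity:
                least_used = self.counter.most_common()[-1][0]
                del self.cache[least_used]
                del self.counter[least_used]
            self.cache[key] = self.load_fn(key)
        self.counter[key] += 1
        return self.cache[key]

cache = TinyModelCache(load_fn=lambda kind: f"<{kind} weights>", capacity=2)
for kind in ["base", "inpaint", "style", "base"]:
    print(kind, "->", cache.get(kind))  # "style" evicts "inpaint"; "base" stays cached
```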
111
+
112
+ def load_clip_model():
113
+ from transformers import CLIPProcessor, CLIPModel
114
+ version = "openai/clip-vit-large-patch14"
115
+ model = CLIPModel.from_pretrained(version).cuda()
116
+ processor = CLIPProcessor.from_pretrained(version)
117
+
118
+ return {
119
+ 'version': version,
120
+ 'model': model,
121
+ 'processor': processor,
122
+ }
123
+
124
+ clip_model = load_clip_model()
125
+
126
+
127
+ class ImageMask(gr.components.Image):
128
+ """
129
+ Sets: source="canvas", tool="sketch"
130
+ """
131
+
132
+ is_template = True
133
+
134
+ def __init__(self, **kwargs):
135
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
136
+
137
+ def preprocess(self, x):
138
+ if x is None:
139
+ return x
140
+ if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
141
+ decode_image = processing_utils.decode_base64_to_image(x)
142
+ width, height = decode_image.size
143
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
144
+ mask[..., -1] = 255
145
+ mask = self.postprocess(mask)
146
+ x = {'image': x, 'mask': mask}
147
+ return super().preprocess(x)
148
+
149
+
150
  class Blocks(gr.Blocks):
151
 
152
  def __init__(
153
+ self,
154
+ theme: str = "default",
155
+ analytics_enabled: Optional[bool] = None,
156
+ mode: str = "blocks",
157
+ title: str = "Gradio",
158
+ css: Optional[str] = None,
159
+ **kwargs,
160
  ):
161
+
162
  self.extra_configs = {
163
  'thumbnail': kwargs.pop('thumbnail', ''),
164
  'url': kwargs.pop('url', 'https://gradio.app/'),
 
173
 
174
  for k, v in self.extra_configs.items():
175
  config[k] = v
176
+
177
  return config
178
+
179
+ '''
180
+ inference model
181
+ '''
182
+
183
+ @torch.no_grad()
184
+ def inference(task, language_instruction, grounding_instruction, inpainting_boxes_nodrop, image,
185
+ alpha_sample, guidance_scale, batch_size,
186
+ fix_seed, rand_seed, actual_mask, style_image,
187
+ *args, **kwargs):
188
+ grounding_instruction = json.loads(grounding_instruction)
189
+ phrase_list, location_list = [], []
190
+ for k, v in grounding_instruction.items():
191
+ phrase_list.append(k)
192
+ location_list.append(v)
193
+
194
+ placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
195
+ image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
196
+
197
+ batch_size = int(batch_size)
198
+ if not 1 <= batch_size <= 4:
199
+ batch_size = 2
200
+
201
+ if style_image == None:
202
+ has_text_mask = 1
203
+ has_image_mask = 0 # then we hack above 'image_list'
204
+ else:
205
+ valid_phrase_len = len(phrase_list)
206
+
207
+ phrase_list += ['placeholder']
208
+ has_text_mask = [1]*valid_phrase_len + [0]
209
+
210
+ image_list = [placeholder_image]*valid_phrase_len + [style_image]
211
+ has_image_mask = [0]*valid_phrase_len + [1]
212
+
213
+ location_list += [ [0.0, 0.0, 1, 0.01] ] # style image grounding location
214
+
215
+ if task == 'Grounded Inpainting':
216
+ alpha_sample = 1.0
217
+
218
+ instruction = dict(
219
+ prompt = language_instruction,
220
+ phrases = phrase_list,
221
+ images = image_list,
222
+ locations = location_list,
223
+ alpha_type = [alpha_sample, 0, 1.0 - alpha_sample],
224
+ has_text_mask = has_text_mask,
225
+ has_image_mask = has_image_mask,
226
+ save_folder_name = language_instruction,
227
+ guidance_scale = guidance_scale,
228
+ batch_size = batch_size,
229
+ fix_seed = bool(fix_seed),
230
+ rand_seed = int(rand_seed),
231
+ actual_mask = actual_mask,
232
+ inpainting_boxes_nodrop = inpainting_boxes_nodrop,
233
+ )
234
+
235
+ get_model = partial(instance.get_model,
236
+ batch_size=batch_size,
237
+ instruction=language_instruction,
238
+ phrase_list=phrase_list)
239
+
240
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
241
+ if task == 'Grounded Generation':
242
+ if style_image == None:
243
+ return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
244
+ else:
245
+ return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
246
+ elif task == 'Grounded Inpainting':
247
+ assert image is not None
248
+ instruction['input_image'] = image.convert("RGB")
249
+ return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
250
+
251
+
252
  def draw_box(boxes=[], texts=[], img=None):
253
  if len(boxes) == 0 and img is None:
254
  return None
 
258
  colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
259
  draw = ImageDraw.Draw(img)
260
  font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
 
261
  for bid, box in enumerate(boxes):
262
  draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
263
  anno_text = texts[bid]
264
+ draw.rectangle([box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]], outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
265
+ draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size*1.2)], anno_text, font=font, fill=(255,255,255))
 
 
 
266
  return img
267
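A minimal usage sketch for draw_box() above, assuming the DejaVuSansMono.ttf font it references is available on the host, as it is on the Space:

```python
from PIL import Image

# One pixel-space box and its label drawn on a blank 512x512 canvas
canvas = Image.new("RGB", (512, 512), "white")
preview = draw_box(boxes=[[64, 64, 256, 256]], texts=["a dog"], img=canvas)
preview.save("parsed_sketch_preview.png")
```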
 
268
  def get_concat(ims):
269
  if len(ims) == 1:
270
  n_col = 1
 
279
  return dst
280
 
281
 
282
+ def auto_append_grounding(language_instruction, grounding_texts):
283
+ for grounding_text in grounding_texts:
284
+ if grounding_text not in language_instruction and grounding_text != 'auto':
285
+ language_instruction += "; " + grounding_text
286
+ return language_instruction
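A quick usage example for auto_append_grounding() above: phrases already present in the caption are left alone, and missing ones are appended after a semicolon.

```python
# Assumes auto_append_grounding() as defined above
print(auto_append_grounding("a dog and an apple", ["a dog", "a birthday cake"]))
# -> "a dog and an apple; a birthday cake"
```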
 
 
 
287
 
 
288
289
290
 
291
+ def generate(task, language_instruction, grounding_texts, sketch_pad,
292
+ alpha_sample, guidance_scale, batch_size,
293
+ fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
294
  state):
295
  if 'boxes' not in state:
296
  state['boxes'] = []
297
+
298
  boxes = state['boxes']
299
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
300
  # assert len(boxes) == len(grounding_texts)
 
306
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
307
 
308
  boxes = (np.asarray(boxes) / 512).tolist()
309
+ grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
310
+
311
+ image = None
312
+ actual_mask = None
313
+ if task == 'Grounded Inpainting':
314
+ image = state.get('original_image', sketch_pad['image']).copy()
315
+ image = center_crop(image)
316
+ image = Image.fromarray(image)
317
+
318
+ if use_actual_mask:
319
+ actual_mask = sketch_pad['mask'].copy()
320
+ if actual_mask.ndim == 3:
321
+ actual_mask = actual_mask[..., 0]
322
+ actual_mask = center_crop(actual_mask, tgt_size=(64, 64))
323
+ actual_mask = torch.from_numpy(actual_mask == 0).float()
324
+
325
+ if state.get('inpaint_hw', None):
326
+ boxes = np.asarray(boxes) * 0.9 + 0.05
327
+ boxes = boxes.tolist()
328
+ grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes) if obj != 'auto'})
329
+
330
+ if append_grounding:
331
+ language_instruction = auto_append_grounding(language_instruction, grounding_texts)
332
+
333
+ gen_images, gen_overlays = inference(
334
+ task, language_instruction, grounding_instruction, boxes, image,
335
+ alpha_sample, guidance_scale, batch_size,
336
+ fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
337
+ )
338
+
339
+ for idx, gen_image in enumerate(gen_images):
340
+
341
+ if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
342
+ hw = min(*state['original_image'].shape[:2])
343
+ gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
344
+ gen_image = Image.fromarray(gen_image)
345
+
346
+ gen_images[idx] = gen_image
347
 
348
  blank_samples = batch_size % 2 if batch_size > 1 else 0
349
+ gen_images = [gr.Image.update(value=x, visible=True) for i,x in enumerate(gen_images)] \
350
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
351
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
352
 
353
  return gen_images + [state]
354
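generate() above serializes the drawn boxes as a phrase-to-box JSON mapping, and inference() parses it back into phrase and location lists. A small round-trip example with hypothetical pixel boxes on the 512x512 sketch pad:

```python
import json
import numpy as np

# Hypothetical boxes in pixels, as drawn on the 512x512 sketch pad
state_boxes = [[64, 64, 256, 256], [300, 120, 480, 400]]
grounding_texts = ["a dog", "an apple"]

# generate() normalizes the pixel boxes to [0, 1] and serializes one box per phrase
boxes = (np.asarray(state_boxes) / 512).tolist()
grounding_instruction = json.dumps(dict(zip(grounding_texts, boxes)))
print(grounding_instruction)

# inference() recovers the phrase and location lists from that JSON string
parsed = json.loads(grounding_instruction)
phrase_list, location_list = list(parsed.keys()), list(parsed.values())
print(phrase_list, location_list)
```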
 
 
356
  def binarize(x):
357
  return (x != 0).astype('uint8') * 255
358
 
 
359
  def sized_center_crop(img, cropx, cropy):
360
  y, x = img.shape[:2]
361
  startx = x // 2 - (cropx // 2)
362
+ starty = y // 2 - (cropy // 2)
363
+ return img[starty:starty+cropy, startx:startx+cropx]
 
364
 
365
  def sized_center_fill(img, fill, cropx, cropy):
366
  y, x = img.shape[:2]
367
  startx = x // 2 - (cropx // 2)
368
+ starty = y // 2 - (cropy // 2)
369
+ img[starty:starty+cropy, startx:startx+cropx] = fill
370
  return img
371
 
 
372
  def sized_center_mask(img, cropx, cropy):
373
  y, x = img.shape[:2]
374
  startx = x // 2 - (cropx // 2)
375
+ starty = y // 2 - (cropy // 2)
376
+ center_region = img[starty:starty+cropy, startx:startx+cropx].copy()
377
  img = (img * 0.2).astype('uint8')
378
+ img[starty:starty+cropy, startx:startx+cropx] = center_region
379
  return img
380
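sized_center_fill() and sized_center_mask() above paste or keep a centered cropx-by-cropy window; a toy check of the fill on a small array:

```python
import numpy as np

# Assumes sized_center_fill() as defined above: an 8x10 zero image gets a
# 4x4 block of 255s pasted at its center (rows 2:6, columns 3:7)
img = np.zeros((8, 10), dtype="uint8")
filled = sized_center_fill(img.copy(), np.full((4, 4), 255, dtype="uint8"), 4, 4)
print(filled[2:6, 3:7])  # all 255
```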
 
 
381
  def center_crop(img, HW=None, tgt_size=(512, 512)):
382
  if HW is None:
383
  H, W = img.shape[:2]
 
387
  img = img.resize(tgt_size)
388
  return np.array(img)
389
 
390
+ def draw(task, input, grounding_texts, new_image_trigger, state):
 
391
  if type(input) == dict:
392
  image = input['image']
393
  mask = input['mask']
394
  else:
395
  mask = input
396
+
397
  if mask.ndim == 3:
398
+ mask = mask[..., 0]
399
 
400
  image_scale = 1.0
401
 
402
+ # resize trigger
403
+ if task == "Grounded Inpainting":
404
+ mask_cond = mask.sum() == 0
405
+ # size_cond = mask.shape != (512, 512)
406
+ if mask_cond and 'original_image' not in state:
407
+ image = Image.fromarray(image)
408
+ width, height = image.size
409
+ scale = 600 / min(width, height)
410
+ image = image.resize((int(width * scale), int(height * scale)))
411
+ state['original_image'] = np.array(image).copy()
412
+ image_scale = float(height / width)
413
+ return [None, new_image_trigger + 1, image_scale, state]
414
+ else:
415
+ original_image = state['original_image']
416
+ H, W = original_image.shape[:2]
417
+ image_scale = float(H / W)
418
+
419
+ mask = binarize(mask)
420
+ if mask.shape != (512, 512):
421
+ # assert False, "should not receive any non- 512x512 masks."
422
+ if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
423
+ mask = center_crop(mask, state['inpaint_hw'])
424
+ image = center_crop(state['original_image'], state['inpaint_hw'])
425
+ else:
426
+ mask = np.zeros((512, 512), dtype=np.uint8)
427
+ # mask = center_crop(mask)
428
  mask = binarize(mask)
429
 
430
  if type(mask) != np.ndarray:
431
  mask = np.array(mask)
432
 
433
+ if mask.sum() == 0 and task != "Grounded Inpainting":
434
  state = {}
435
 
436
+ if task != 'Grounded Inpainting':
437
+ image = None
438
+ else:
439
+ image = Image.fromarray(image)
440
 
441
  if 'boxes' not in state:
442
  state['boxes'] = []
 
465
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
466
  grounding_texts = [x for x in grounding_texts if len(x) > 0]
467
  if len(grounding_texts) < len(state['boxes']):
468
+ grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
469
+
470
  box_image = draw_box(state['boxes'], grounding_texts, image)
471
 
472
+ if box_image is not None and state.get('inpaint_hw', None):
473
+ inpaint_hw = state['inpaint_hw']
474
+ box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
475
+ original_image = state['original_image'].copy()
476
+ box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
477
 
478
+ return [box_image, new_image_trigger, image_scale, state]
479
 
480
  def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
481
  if task != 'Grounded Inpainting':
482
  sketch_pad_trigger = sketch_pad_trigger + 1
483
  blank_samples = batch_size % 2 if batch_size > 1 else 0
484
+ out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
485
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
486
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
487
+ state = {}
488
+ return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
489
+
490
+ css = """
491
+ #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
492
+ {
493
+ height: var(--height) !important;
494
+ max-height: var(--height) !important;
495
+ min-height: var(--height) !important;
496
+ }
497
+ #paper-info a {
498
+ color:#008AD7;
499
+ text-decoration: none;
500
+ }
501
+ #paper-info a:hover {
502
+ cursor: pointer;
503
+ text-decoration: none;
504
+ }
505
+ """
506
+
507
+ rescale_js = """
508
+ function(x) {
509
+ const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
510
+ let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
511
+ const image_width = root.querySelector('#img2img_image').clientWidth;
512
+ const target_height = parseInt(image_width * image_scale);
513
+ document.body.style.setProperty('--height', `${target_height}px`);
514
+ root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
515
+ root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
516
+ return x;
517
+ }
518
+ """
519
+
520
+ with Blocks(
521
+ css=css,
522
+ analytics_enabled=False,
523
+ title="GLIGen demo",
524
+ ) as main:
525
+ description = """<p style="text-align: center; font-weight: bold;">
526
+ <span style="font-size: 28px">GLIGen: Open-Set Grounded Text-to-Image Generation</span>
527
+ <br>
528
+ <span style="font-size: 18px" id="paper-info">
529
+ [<a href="https://gligen.github.io" target="_blank">Project Page</a>]
530
+ [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
531
+ [<a href="https://github.com/gligen/GLIGEN" target="_blank">GitHub</a>]
532
+ </span>
533
+ </p>
534
+ <p>
535
+ To ground concepts of interest with desired spatial specification, please (1) &#9000;&#65039; enter the concept names in <em> Grounding Instruction</em>, and (2) &#128433;&#65039; draw their corresponding bounding boxes one by one using <em> Sketch Pad</em> -- the parsed boxes will be displayed automatically.
536
+ <br>
537
+ For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/gligen/demo?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
538
+ </p>
539
  """
540
+ gr.HTML(description)
541
 
542
+ with gr.Row():
543
+ with gr.Column(scale=4):
544
  sketch_pad_trigger = gr.Number(value=0, visible=False)
545
  sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
546
  init_white_trigger = gr.Number(value=0, visible=False)
547
  image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
548
  new_image_trigger = gr.Number(value=0, visible=False)
549
 
550
+ task = gr.Radio(
551
+ choices=["Grounded Generation", 'Grounded Inpainting'],
552
+ type="value",
553
+ value="Grounded Generation",
554
+ label="Task",
555
+ )
556
+ language_instruction = gr.Textbox(
557
+ label="Language instruction",
558
+ )
559
+ grounding_instruction = gr.Textbox(
560
+ label="Grounding instruction (Separated by semicolon)",
561
+ )
562
  with gr.Row():
563
+ sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
564
  out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
 
 
 
565
  with gr.Row():
566
  clear_btn = gr.Button(value='Clear')
567
  gen_btn = gr.Button(value='Generate')
 
568
  with gr.Accordion("Advanced Options", open=False):
569
  with gr.Column():
570
+ alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
571
+ guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
572
+ batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
573
+ append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
574
+ use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
575
+ with gr.Row():
576
+ fix_seed = gr.Checkbox(value=True, label="Fixed seed")
577
+ rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
578
+ with gr.Row():
579
+ use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition")
580
+ style_cond_image = gr.Image(type="pil", label="Style Condition", visible=False, interactive=True)
581
+ with gr.Column(scale=4):
582
+ gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
583
+ with gr.Row():
584
+ out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
585
+ out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
586
+ with gr.Row():
587
+ out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
588
+ out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
589
+
590
+ state = gr.State({})
591
+
592
+ class Controller:
593
+ def __init__(self):
594
+ self.calls = 0
595
+ self.tracks = 0
596
+ self.resizes = 0
597
+ self.scales = 0
598
+
599
+ def init_white(self, init_white_trigger):
600
+ self.calls += 1
601
+ return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger+1
602
+
603
+ def change_n_samples(self, n_samples):
604
+ blank_samples = n_samples % 2 if n_samples > 1 else 0
605
+ return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
606
+ + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
607
+
608
+ def resize_centercrop(self, state):
609
+ self.resizes += 1
610
+ image = state['original_image'].copy()
611
+ inpaint_hw = int(0.9 * min(*image.shape[:2]))
612
+ state['inpaint_hw'] = inpaint_hw
613
+ image_cc = center_crop(image, inpaint_hw)
614
+ # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
615
+ return image_cc, state
616
+
617
+ def resize_masked(self, state):
618
+ self.resizes += 1
619
+ image = state['original_image'].copy()
620
+ inpaint_hw = int(0.9 * min(*image.shape[:2]))
621
+ state['inpaint_hw'] = inpaint_hw
622
+ image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
623
+ state['masked_image'] = image_mask.copy()
624
+ # print(f'mask triggered {self.resizes}')
625
+ return image_mask, state
626
+
627
+ def switch_task_hide_cond(self, task):
628
+ cond = False
629
+ if task == "Grounded Generation":
630
+ cond = True
631
+
632
+ return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None, visible=False), gr.Slider.update(visible=cond), gr.Checkbox.update(visible=(not cond), value=False)
633
+
634
+ controller = Controller()
635
+ main.load(
636
+ lambda x:x+1,
637
+ inputs=sketch_pad_trigger,
638
+ outputs=sketch_pad_trigger,
639
+ queue=False)
640
+ sketch_pad.edit(
641
+ draw,
642
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
643
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
644
+ queue=False,
645
+ )
646
+ grounding_instruction.change(
647
+ draw,
648
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
649
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
650
+ queue=False,
651
+ )
652
+ clear_btn.click(
653
+ clear,
654
+ inputs=[task, sketch_pad_trigger, batch_size, state],
655
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
656
+ queue=False)
657
+ task.change(
658
+ partial(clear, switch_task=True),
659
+ inputs=[task, sketch_pad_trigger, batch_size, state],
660
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
661
+ queue=False)
662
+ sketch_pad_trigger.change(
663
+ controller.init_white,
664
+ inputs=[init_white_trigger],
665
+ outputs=[sketch_pad, image_scale, init_white_trigger],
666
+ queue=False)
667
+ sketch_pad_resize_trigger.change(
668
+ controller.resize_masked,
669
+ inputs=[state],
670
+ outputs=[sketch_pad, state],
671
+ queue=False)
672
+ batch_size.change(
673
+ controller.change_n_samples,
674
+ inputs=[batch_size],
675
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
676
+ queue=False)
677
+ gen_btn.click(
678
+ generate,
679
+ inputs=[
680
+ task, language_instruction, grounding_instruction, sketch_pad,
681
+ alpha_sample, guidance_scale, batch_size,
682
+ fix_seed, rand_seed,
683
+ use_actual_mask,
684
+ append_grounding, style_cond_image,
685
+ state,
686
+ ],
687
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
688
+ queue=True
689
+ )
690
+ sketch_pad_resize_trigger.change(
691
+ None,
692
+ None,
693
+ sketch_pad_resize_trigger,
694
+ _js=rescale_js,
695
+ queue=False)
696
+ init_white_trigger.change(
697
+ None,
698
+ None,
699
+ init_white_trigger,
700
+ _js=rescale_js,
701
+ queue=False)
702
+ use_style_cond.change(
703
+ lambda cond: gr.Image.update(visible=cond),
704
+ use_style_cond,
705
+ style_cond_image,
706
+ queue=False)
707
+ task.change(
708
+ controller.switch_task_hide_cond,
709
+ inputs=task,
710
+ outputs=[use_style_cond, style_cond_image, alpha_sample, use_actual_mask],
711
+ queue=False)
712
+
713
+ with gr.Column():
714
+ gr.Examples(
715
+ examples=[
716
+ [
717
+ "images/blank.png",
718
+ "Grounded Generation",
719
+ "a dog and an apple",
720
+ "a dog;an apple",
721
  ],
722
+ [
723
+ "images/blank.png",
724
+ "Grounded Generation",
725
+ "John Lennon is using a pc",
726
+ "John Lennon;a pc",
727
+ [
728
+ "images/blank.png",
729
+ "Grounded Generation",
730
+ "a painting of a fox sitting in a field at sunrise in the style of Claude Mone",
731
+ "fox;sunrise",
732
  ],
733
+ ],
734
+ [
735
+ "images/blank.png",
736
+ "Grounded Generation",
737
+ "a beautiful painting of hot dog by studio ghibli, octane render, brilliantly coloured",
738
+ "hot dog",
739
+ ],
740
+ [
741
+ "images/blank.png",
742
+ "Grounded Generation",
743
+ "a sport car, unreal engine, global illumination, ray tracing",
744
+ "a sport car",
745
+ ],
746
+ [
747
+ "images/flower_beach.jpg",
748
+ "Grounded Inpainting",
749
+ "a squirrel and the space needle",
750
+ "a squirrel;the space needle",
751
+ ],
752
+ [
753
+ "images/arg_corgis.jpeg",
754
+ "Grounded Inpainting",
755
+ "a dog and a birthday cake",
756
+ "a dog; a birthday cake",
757
+ ],
758
+ [
759
+ "images/teddy.jpg",
760
+ "Grounded Inpainting",
761
+ "a teddy bear wearing a santa claus red shirt; holding a Christmas gift box on hand",
762
+ "a santa claus shirt; a Christmas gift box",
763
+ ],
764
+ ],
765
+ inputs=[sketch_pad, task, language_instruction, grounding_instruction],
766
+ outputs=None,
767
+ fn=None,
768
+ cache_examples=False,
769
+ )
770
+
771
+ main.queue(concurrency_count=1, api_open=False)
772
+ main.launch(share=False, show_api=False, show_error=True)
773
environment.yaml ADDED
@@ -0,0 +1,29 @@
1
+ name: loco_gligen_demo
2
+ channels:
3
+ - xformers/label/dev
4
+ - pytorch
5
+ - defaults
6
+ dependencies:
7
+ - python=3.10.8
8
+ - pip=22.2.2
9
+ - cudatoolkit=11.3
10
+ - pytorch=1.12.1
11
+ - torchvision=0.13.1
12
+ - numpy=1.23.1
13
+ - xformers
14
+ - pip:
15
+ - omegaconf==2.1.1
16
+ - albumentations==1.3.0
17
+ - opencv-python
18
+ - imageio==2.9.0
19
+ - imageio-ffmpeg==0.4.2
20
+ - pytorch-lightning==1.4.2
21
+ - test-tube>=0.7.5
22
+ - streamlit==1.12.1
23
+ - einops==0.3.0
24
+ - git+https://github.com/openai/CLIP.git
25
+ - protobuf~=3.20.1
26
+ - torchmetrics==0.6.0
27
+ - transformers==4.19.2
28
+ - kornia==0.6.0
29
+ - gradio==3.16.0
requirements.txt CHANGED
@@ -1,14 +1,18 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu113
2
- torch
3
- torchvision==0.14.0
4
- omegaconf==2.2.3
 
5
  opencv-python
6
  imageio==2.9.0
7
- transformers==4.24.0
8
- diffusers==0.7.2
9
- accelerate==0.13.2
10
- scipy==1.9.1
 
11
  git+https://github.com/openai/CLIP.git
12
- hydra-core==1.2.0
13
- tqdm
14
- gradio==3.23.0
 
 
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1
3
+ xformers==0.0.16
4
+ omegaconf==2.1.1
5
+ albumentations==1.3.0
6
  opencv-python
7
  imageio==2.9.0
8
+ imageio-ffmpeg==0.4.2
9
+ pytorch-lightning==1.4.2
10
+ test-tube>=0.7.5
11
+ streamlit==1.17.0
12
+ einops==0.3.0
13
  git+https://github.com/openai/CLIP.git
14
+ protobuf~=3.20.1
15
+ torchmetrics==0.6.0
16
+ transformers==4.19.2
17
+ kornia==0.6.0
18
+ gradio==3.19.1
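A quick, hedged way to confirm which of these pins actually resolved in the running Space (importlib.metadata is in the standard library; the package names mirror the list above):

```python
from importlib.metadata import version

# Print the installed versions of a few of the pinned packages
for pkg in ["torch", "torchvision", "transformers", "gradio"]:
    print(pkg, version(pkg))
```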