Pusheen committed on
Commit
f016d3e
1 Parent(s): fff1e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py CHANGED
@@ -186,9 +186,85 @@ def click_on_display(language_instruction, grounding_texts, sketch_pad,
186
 
187
  return gen_images + [state]
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
190
  loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
191
  state):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  if 'boxes' not in state:
193
  state['boxes'] = []
194
  boxes = state['boxes']
 
186
 
187
  return gen_images + [state]
188
 
189
def Pharse2idx(prompt, phrases):
    """Map each grounding phrase to its 1-based word positions in *prompt*.

    Parameters
    ----------
    prompt : str
        The full language instruction.
    phrases : str
        Semicolon-separated grounding phrases (split on ``'; '``);
        underscores inside a phrase are treated as spaces.

    Returns
    -------
    list[list[int]] or None
        One list of 1-based word positions per phrase (every occurrence of
        the phrase contributes its word positions), or ``None`` when any
        phrase cannot be located in the prompt.
    """
    def match(prompt_words, phrase_words):
        # Exact window match, or each prompt word may be the simple
        # plural ('s' / 'es') of the corresponding phrase word.
        if prompt_words == phrase_words:
            return True
        for prompt_word, phrase_word in zip(prompt_words, phrase_words):
            if prompt_word not in (phrase_word, phrase_word + 's', phrase_word + 'es'):
                return False
        return True

    phrase_list = [x.replace('_', ' ') for x in phrases.split('; ')]

    # Detach punctuation so each mark tokenizes as its own word and does
    # not glue itself to the preceding phrase word.
    for punc in [',', '.', ';', ':', '?', '!']:
        prompt = prompt.replace(punc, ' ' + punc)
    words = prompt.split()

    object_positions = []
    for phrase in phrase_list:
        phrase_words = phrase.split()
        positions = []
        # Slide a window of the phrase's length across the prompt words.
        for i in range(len(words) - len(phrase_words) + 1):
            if match(words[i:i + len(phrase_words)], phrase_words):
                # Positions are 1-based: position 0 is presumably reserved
                # for the start-of-text token downstream — TODO confirm.
                positions += list(range(i + 1, i + len(phrase_words) + 1))
        if not positions:
            # Phrase absent from the prompt: signal failure to the caller
            # (debug prints removed; the caller decides how to report it).
            return None
        object_positions.append(positions)
    return object_positions
218
+
219
+
220
def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
             loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
             state):
    """Run grounded image generation for the UI.

    ``language_instruction`` is the prompt; ``grounding_texts`` is the
    semicolon-separated phrase string; ``state['boxes']`` holds the boxes
    drawn on the sketch pad (pixel coordinates on a 512x512 canvas —
    TODO confirm canvas size against the sketch-pad component).

    Returns the list of gradio image updates followed by the state dict.

    Raises ``ValueError`` when fewer boxes than phrases were drawn, or when
    a grounding phrase cannot be located in the prompt.
    """
    if 'boxes' not in state:
        state['boxes'] = []
    boxes = state['boxes']

    # Keep the raw phrase string for Pharse2idx; the split/stripped copy is
    # used for pairing phrases with boxes below.
    phrases = grounding_texts
    grounding_texts = [x.strip() for x in grounding_texts.split(';')]

    if len(boxes) != len(grounding_texts):
        if len(boxes) < len(grounding_texts):
            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
        # More boxes than phrases: pad the phrase list with empty strings.
        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))

    # Normalize pixel boxes to [0, 1] and wrap each in a singleton list,
    # the per-object box-group shape inference() expects.
    boxes = (np.asarray(boxes) / 512).tolist()
    boxes = [[box] for box in boxes]

    # NOTE: a previous revision recomputed object_positions by hand with
    # language_instruction_list.index(word) (which raised ValueError on any
    # unmatched word) and then discarded the result — that dead code and an
    # unused json.dumps(grounding_instruction) have been removed.
    object_positions = Pharse2idx(language_instruction, phrases)
    if object_positions is None:
        # Previously None was silently forwarded into inference(); fail
        # loudly instead so the UI surfaces a usable message.
        raise ValueError(
            "Could not locate every grounding phrase {!r} in the prompt {!r}. "
            "Please make sure each phrase appears verbatim in the prompt.".format(
                phrases, language_instruction))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
                           object_positions, batch_size, loss_scale, loss_threshold, max_iter,
                           max_step, rand_seed, guidance_scale)

    # Pad the 4-slot gallery: visible blanks keep rows even for odd batch
    # sizes; the remaining slots are hidden entirely.
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    gen_images = [gr.Image.update(value=x, visible=True) for x in gen_images] \
        + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
        + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]

    return gen_images + [state]
264
+
265
+ def generate_legacy(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
266
+ loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
267
+ state):
268
  if 'boxes' not in state:
269
  state['boxes'] = []
270
  boxes = state['boxes']